In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import numpy as np
import os
import shutil
import  opennmt.inputters.record_inputter as inpu
import tensorflow as tf
import warnings
import time
import torch.utils.data as data
from torch import nn
from PIL import Image
import os
import os.path
import random
from sklearn.metrics import accuracy_score, confusion_matrix
from pytorch_i3d import InceptionI3d


In [None]:
warnings.filterwarnings('ignore')

In [None]:
class I(torch.nn.Module):
    """Identity module: ``forward`` returns its input unchanged."""

    def __init__(self):
        super(I, self).__init__()

    def forward(self, x):
        return x

    def extra_repr(self):
        return 'identity'


# Pretrained I3D RGB checkpoint loaded when VideoClass is built without an explicit path.
DEFAULT_I3D_WEIGHTS = '/home/alptekin/Desktop/pytorch-i3d-master/models/rgb_imagenet.pt'


class VideoClass(torch.nn.Module):
    """``InceptionI3d`` backbone followed by a linear classification head.

    Args:
        load_model: path of the state dict loaded into the backbone.
        num_classes: number of logits produced by ``get_class``.
    """

    def __init__(self, load_model=DEFAULT_I3D_WEIGHTS, num_classes=61):
        super(VideoClass, self).__init__()
        self.model = InceptionI3d()
        # map_location='cpu' lets a checkpoint saved on a GPU load on any machine; the
        # tensors are copied into the module's parameters, which callers move with .to().
        self.model.load_state_dict(torch.load(load_model, map_location='cpu'))
        # NOTE(review): get_class/get_feature only call extract_features, so this attribute
        # is not used by them; confirm it is still needed.
        self.model.fc = I()
        self.fc_out = torch.nn.Linear(in_features=1024, out_features=num_classes)

    @staticmethod
    def _pool(features):
        """Average every axis after the channel axis, giving ``(batch, channels)``.

        Identical to ``squeeze()`` when those trailing axes all have size 1, but keeps the
        batch axis for a batch of one (squeeze dropped it, breaking ``argmax(dim=1)``) and
        also handles a temporal axis longer than 1.
        """
        if features.dim() <= 2:
            return features
        return features.flatten(start_dim=2).mean(dim=2)

    def get_class(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for clip batch ``x``."""
        return self.fc_out(self._pool(self.model.extract_features(x)))

    def get_feature(self, x):
        """Return pooled backbone features of shape ``(batch, channels)`` for ``x``."""
        return self._pool(self.model.extract_features(x))

In [None]:
def make_weights_for_balanced_classes(images, nclasses, power=1, threshold=None, flag=1, verbose=True):
    """Compute one sampling weight per sample for class-balanced sampling.

    The weight of a class is ``(N / count[class]) ** power``, where ``N`` is the number of
    samples; each sample receives the weight of its class.

    Args:
        images: sequence of ``(path, label)`` pairs; only ``label`` is used and must lie
            in ``[0, nclasses)``.
        nclasses: number of classes.
        power: exponent applied to the inverse-frequency weight.
        threshold: if given, class weights are rescaled so that the smallest non-zero
            weight becomes 1, then clipped to ``[0, threshold]``.
        flag: classes with index ``< flag`` are excluded from the inverse-frequency
            weighting; if ``flag`` is truthy, class 0 then gets the fixed weight 0.01.
        verbose: print the per-class weight vector.

    Returns:
        list with one weight per element of ``images`` (same order).

    Raises:
        IndexError: if a label lies outside ``[0, nclasses)``.
    """
    labels = np.fromiter((item[1] for item in images), dtype=np.int64)
    if labels.size and (labels.min() < 0 or labels.max() >= nclasses):
        raise IndexError('labels must lie in [0, %d)' % nclasses)
    count = np.bincount(labels, minlength=nclasses)
    N = float(count.sum())

    # Only classes that occur and whose index is >= flag get an inverse-frequency weight.
    eligible = count > 0
    eligible[:int(flag)] = False
    weight_per_class = np.zeros(nclasses)
    weight_per_class[eligible] = N / count[eligible]
    weight_per_class = weight_per_class ** power

    if threshold is not None:
        nonzero = weight_per_class[weight_per_class != 0]
        # With no non-zero weight there is nothing to rescale (min() would raise here).
        if nonzero.size:
            weight_per_class /= nonzero.min()
        weight_per_class = np.clip(weight_per_class, a_min=0, a_max=threshold)
    if flag:
        weight_per_class[0] = 0.01
    if verbose:
        print(weight_per_class)
    return list(weight_per_class[labels])

In [None]:
class SequenceFilelist(data.Dataset):
    """Dataset of fixed-length frame clips built around each entry of a flat frame list.

    ``image_list`` is a sequence of ``(path, label)`` pairs. Item ``index`` is a clip of
    ``clip_len`` consecutive entries that always contains entry ``index`` (its position in
    the clip is drawn uniformly at random), paired with the label of entry ``index``.
    Neighbour indices that fall outside the list are clamped to its first/last entry.

    NOTE(review): the window runs over the flat list, so when the list concatenates
    several videos a clip near a boundary mixes frames of adjacent videos.
    """

    def __init__(self, image_list, transform=None, clip_len=16):
        self.imgs = image_list
        self.transform = transform
        self.clip_len = clip_len

    def __getitem__(self, index):
        _, target = self.imgs[index]
        # A fresh OS-seeded generator per call keeps offsets independent across DataLoader
        # workers (what the former np.random.seed() achieved) without overwriting the
        # global NumPy RNG state that worker_init_fn or user code may have set.
        rng = np.random.default_rng()
        first = index - int(rng.integers(self.clip_len))
        last_valid = len(self.imgs) - 1

        frames = []
        for aux_index in range(first, first + self.clip_len):
            aux_index = min(max(aux_index, 0), last_valid)
            aux_impath, _ = self.imgs[aux_index]
            img = self.img_loader(aux_impath)
            if self.transform is not None:
                img = self.transform(img)
            frames.append(img)
        # Stack on a new time axis: (C, H, W) frames -> (C, clip_len, H, W), float32 like
        # the former preallocated buffer, but without hard-coding the frame size.
        video_imgs = torch.stack(frames, dim=1).float()
        return video_imgs, target

    def __len__(self):
        return len(self.imgs)

    def img_loader(self, path):
        """Open the image at ``path`` and convert it to RGB."""
        return Image.open(path).convert('RGB')

In [None]:
img_size = 224
# Per-channel normalisation statistics shared by the train and test pipelines.
norm_mean = [0.43216, 0.394666, 0.37645]
norm_std = [0.22803, 0.22145, 0.216989]

# Training pipeline: mild colour jitter, then resize, tensor conversion, normalisation.
# NOTE(review): SequenceFilelist applies this transform to every frame separately, so
# ColorJitter draws new random factors per frame inside one clip; confirm that
# frame-level flicker is intended rather than one jitter per clip.
trans_train = transforms.Compose([
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=norm_mean, std=norm_std),
])

# Evaluation pipeline: the same steps without colour jitter.
trans_test = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=norm_mean, std=norm_std),
])

In [None]:
# Clips per training batch; evaluation batches are twice as large.
batch_size = 16
test_batch_size = 2 * batch_size

In [None]:
# Parse the frame list. Each non-blank line is "<frame path> <label>".
# - data_list: every frame as (last three path components joined, int label).
# - video_list: frames grouped by their parent directory name, label-0 frames excluded.
# - counts: number of non-zero-label frames per class.
data_list = []
video_list = dict()
counts = [0] * 61
with open('danish_nz_images.txt') as f:
    for line in f:
        if not line.strip():
            continue  # tolerate blank / trailing lines
        # rsplit: only the last whitespace-separated field is the label, so paths that
        # contain spaces still parse.
        path, label = line.rsplit(maxsplit=1)
        parts = path.split('/')
        video_name = parts[-2]
        my_path = os.path.join(*parts[-3:])
        label = int(float(label))
        data_list.append((my_path, label))
        if label == 0:
            continue
        video_list.setdefault(video_name, []).append((my_path, label))
        counts[label] += 1

In [None]:
# Regroup at video granularity: one entry per video, each a list of its (path, label)
# frames in file order (label-0 frames were never added to video_list).
data_list = list(video_list.values())

In [None]:
def worker_init_fn(worker_id):
    """Seed NumPy's and Python's global RNGs inside a DataLoader worker.

    PyTorch sets ``torch.initial_seed()`` to ``base_seed + worker_id`` in each worker,
    with ``base_seed`` drawn anew by the main process for each loader iterator, so the
    derived seed differs per worker and per epoch. It is reduced modulo 2**32 because
    ``np.random.seed`` only accepts 32-bit seeds.

    The former ``np.random.get_state()[1][0] + worker_id`` could overflow that range.
    It also repeated the same per-worker seeds every epoch, because each worker starts
    from a copy of the parent's NumPy state.

    Args:
        worker_id: worker index passed by DataLoader. It is already folded into
            ``torch.initial_seed()`` and is kept for the callback signature.
    """
    seed = torch.initial_seed() % 2**32
    np.random.seed(seed)
    random.seed(seed)

In [None]:
# Split at the video level (each entry of data_list is one video's frames), so no video
# contributes frames to both train and dev. A dedicated seeded Random instance makes the
# split identical across kernel restarts. Otherwise a checkpoint saved in one session
# would later be scored on a dev set that partly overlaps its training videos.
SPLIT_SEED = 0
random.Random(SPLIT_SEED).shuffle(data_list)
data_size = len(data_list)
dev_size = round(data_size * 0.1)  # 10% of the videos go to dev
dev_list = data_list[:dev_size]
train_list = data_list[dev_size:]

In [None]:
# Flatten the training videos into one frame-level list of (path, label) pairs.
# NOTE(review): SequenceFilelist windows over this flat list, so clips near a video
# boundary can include frames of the neighbouring video.
train_list = [frame for video in train_list for frame in video]

In [None]:
# Flatten the dev videos into one frame list, dropping the first and last 10 frames of
# each video (videos with 20 or fewer frames contribute nothing).
dev_list = [frame for video in dev_list for frame in video[10:-10]]

In [None]:
def get_class_dist(l, nclasses=61):
    """Count samples per class, excluding class 0.

    Args:
        l: iterable of ``(path, label)`` pairs with labels in ``[0, nclasses)``.
        nclasses: total number of classes, including class 0.

    Returns:
        float ndarray of length ``nclasses - 1``; entry ``k`` is the number of samples
        with label ``k + 1``.

    Raises:
        ValueError: if a label is negative or ``>= nclasses``. Previously a negative label
            was silently counted as the last class through negative indexing.
    """
    labels = np.fromiter((item[1] for item in l), dtype=np.int64)
    if labels.size and (labels.min() < 0 or labels.max() >= nclasses):
        raise ValueError('labels must lie in [0, %d)' % nclasses)
    return np.bincount(labels, minlength=nclasses).astype(float)[1:]

In [None]:
# Per-class frame counts (classes 1..60) of the training and dev splits.
t_d, d_d = get_class_dist(train_list), get_class_dist(dev_list)

In [None]:
# Normalised per-class frame distribution of each split. get_class_dist drops class 0,
# so position k of each array corresponds to class k + 1.
class_ids = np.arange(1, len(t_d) + 1)
fig, ax = plt.subplots()
ax.plot(class_ids, t_d / t_d.sum(), label='train')
ax.plot(class_ids, d_d / d_d.sum(), label='dev')
ax.set(xlabel='class label', ylabel='fraction of frames', title='Class distribution per split')
ax.legend();

In [None]:
# Clip datasets: training clips use the jittering pipeline, dev clips do not.
train_data = SequenceFilelist(image_list=train_list, transform=trans_train)
dev_data = SequenceFilelist(image_list=dev_list, transform=trans_test)

# Class-balanced oversampling of training frames. flag=0 weights class 0 like the others;
# weights grow as sqrt(N / count) and are capped at 120x the smallest non-zero weight.
# The sampler draws len(weights) indices per epoch (with replacement by default).
weights = make_weights_for_balanced_classes(train_list, 61, threshold=120, power=0.5, flag=0)
weights = torch.as_tensor(np.array(weights), dtype=torch.float64)
sampler = data.WeightedRandomSampler(weights, num_samples=len(weights))

train_loader = data.DataLoader(
    train_data,
    batch_size=batch_size,
    sampler=sampler,
    num_workers=8,
    worker_init_fn=worker_init_fn,
)
dev_loader = data.DataLoader(dev_data, batch_size=test_batch_size, num_workers=8)

In [None]:
# Use the first CUDA device when one is available; otherwise fall back to the CPU so the
# notebook still runs (slowly) instead of failing at model.to(device).
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
    print('CUDA not available; running on CPU')

In [None]:
# Build the classifier (loads the pretrained I3D backbone weights), place it on `device`
# and switch it to training mode.
model = VideoClass().to(device)
model.train()
print('model created')

In [None]:
# Batches per training epoch; sets the progress-report and dev-evaluation cadence.
iter_num = len(train_loader)

In [None]:
crit = torch.nn.CrossEntropyLoss()
lr = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=2e-5)
# Multiply the learning rate by 0.2 every 4 scheduler.step() calls.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.2)

alpha = 0.2          # NOTE(review): not referenced by any cell visible here
my_step = 39         # numeric suffix of the next checkpoint file, 'model<my_step>.pth'
inner = 4            # the dev set is evaluated every iter_num // inner training batches
report_number = 20   # training batches between progress prints

In [None]:
def _dev_accuracy(model, loader, device):
    """Return the accuracy of ``model.get_class`` argmax predictions over ``loader``.

    Switches ``model`` to eval mode and runs without gradient tracking; the caller must
    switch back to train mode. Prints a progress fraction per batch.
    """
    model.eval()
    dev_preds, dev_labels = [], []
    with torch.no_grad():
        for ti, (x, y) in enumerate(loader):
            logits = model.get_class(x.to(device))
            print(ti / len(loader), end='\r')
            dev_preds.extend(logits.argmax(dim=1).cpu().numpy().reshape(-1))
            dev_labels.extend(y.cpu().numpy().reshape(-1))
    return accuracy_score(y_pred=dev_preds, y_true=dev_labels)


# Evaluate on the dev set every `eval_every` batches (about `inner` times per epoch);
# max(1, ...) avoids `% 0` when an epoch has fewer batches than `inner`.
eval_every = max(1, iter_num // inner)

for epoch in range(1, 40):
    model.train()
    start = time.time()
    epoch_loss = 0.0    # sum of batch losses over the whole epoch (end-of-epoch report)
    running_loss = 0.0  # sum of batch losses since the last progress line
    # Training predictions since the last progress line. They are kept separate from the
    # dev predictions so dev results never leak into the training accuracy.
    train_preds, train_labels = [], []
    # Learning rate used for every update of this epoch (scheduler is stepped at the end).
    epoch_lr = scheduler.get_last_lr()[0]

    for i, (x, y) in enumerate(train_loader):
        optimizer.zero_grad()
        logits = model.get_class(x.to(device))
        loss = crit(logits, y.to(device))
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        epoch_loss += batch_loss
        train_preds.extend(logits.argmax(dim=1).detach().cpu().numpy().reshape(-1))
        train_labels.extend(y.cpu().numpy().reshape(-1))

        if (i + 1) % report_number == 0:
            ac = accuracy_score(y_pred=train_preds, y_true=train_labels)
            print((i + 1), running_loss / report_number, ac, end='\r')
            train_preds, train_labels = [], []
            running_loss = 0.0

        if (i + 1) % eval_every == 0:
            dev_acc = _dev_accuracy(model, dev_loader, device)
            print('Test Accuracy:', dev_acc, scheduler.get_last_lr())
            model_path = 'model' + str(my_step) + '.pth'
            torch.save(model, model_path)
            model.train()
            my_step += 1

    # Since PyTorch 1.1 the scheduler is stepped after the optimizer updates; stepping
    # it at the start of the epoch skipped the first value of the schedule.
    scheduler.step()

    print('----------------------' + str(epoch) + '------------------------')
    print('------------------------------------------------')
    print('Loss:', epoch_loss / iter_num, epoch_lr)
    print('Elasped Time:', round(time.time() - start))
    print('------------------------------------------------')

In [None]:
# Final dev-set evaluation. Afterwards `pred_list` / `label_list` hold the argmax
# predictions and the true labels of every dev clip.
model.to(device)
model.eval()
pred_list, label_list = [], []
for ti, (x, y) in enumerate(dev_loader):
    with torch.no_grad():
        logits = model.get_class(x.to(device))
    preds = logits.argmax(dim=1)
    print(ti / len(dev_loader), end='\r')
    pred_list.extend(preds.cpu().numpy().ravel())
    label_list.extend(y.cpu().numpy().ravel())
print('Test Accuracy:', accuracy_score(y_pred=pred_list, y_true=label_list))

In [None]:
# Convert the collected dev predictions into a NumPy array.
pred_list = np.asarray(pred_list)