In [1]:
import datetime
import sys
from sys import path
from os.path import join as join_paths

import numpy as np
import pandas as pd
import torch
import torch.optim as optim
from numpy.random import randint
from skimage import io
from tqdm.notebook import tqdm
from torch.nn.utils import clip_grad_norm
from torch.utils.data import DataLoader, Dataset

path.append("./")
import net_sphere
from calculateEvaluationCCC import calculateCCC

# Define parameters
use_cuda = torch.cuda.is_available()

lr = 0.01
bs = 32
n_epoch = 30
lr_steps = [8,16,24]

gd = 20 # clip gradient
eval_freq = 3
print_freq = 20
num_worker = 4
num_seg = 16
flag_biLSTM = True

classnum = 7

train_list_path = './support_tables/train_list_lstm.txt'
val_list_path = './support_tables/validation_list_lstm.txt'
model_path = './model/sphere20a_20171020.pth'
train_data_path: str = "/Users/leonardoalchieri/Datasets/OMGEmotionChallenge/Train_Set/trimmed_faces"
validation_data_path: str = "/Users/leonardoalchieri/Datasets/OMGEmotionChallenge/Validation_Set/trimmed_faces"

In [2]:
sphereface = getattr(net_sphere,'sphere20a')()
sphereface.load_state_dict(torch.load(model_path))
sphereface.feature = True # remove the last fc layer because we need to use LSTM first

class Net(torch.nn.Module):
    def __init__(self, sphereface):
        super(Net, self).__init__()
        self.sphereface = sphereface
        self.linear = torch.nn.Linear(512,2)
        self.tanh = torch.nn.Tanh()
        self.avgPool = torch.nn.AvgPool2d((num_seg,1), stride=1)
        self.LSTM = torch.nn.LSTM(512, 512, 1, batch_first = True, dropout=0.2, bidirectional=flag_biLSTM)  # Input dim, hidden dim, num_layer
        for name, param in self.LSTM.named_parameters():
            if 'bias' in name:
                torch.nn.init.constant(param, 0.0)
            elif 'weight' in name:
                torch.nn.init.orthogonal(param)
        
    def sequentialLSTM(self, input, hidden=None):

        input_lstm = input.view([-1,num_seg, input.shape[1]])
        batch_size = input_lstm.shape[0]
        feature_size = input_lstm.shape[2]

        self.LSTM.flatten_parameters()
            
        output_lstm, hidden = self.LSTM(input_lstm)
        if flag_biLSTM:
             output_lstm = output_lstm.contiguous().view(batch_size, output_lstm.size(1), 2, -1).sum(2).view(batch_size, output_lstm.size(1), -1) 

        # avarage the output of LSTM
        output_lstm = output_lstm.view(batch_size,1,num_seg,-1)
        out = self.avgPool(output_lstm)
        out = out.view(batch_size,-1)
        return out
    
    def forward(self, x):
        x = self.sphereface(x)
        x = self.sequentialLSTM(x)
        x = self.linear(x)
        x = self.tanh(x)
        
        return x

In [3]:
def printoneline(*argv):
    s = ''
    for arg in argv: s += str(arg) + ' '
    s = s[:-1]
    sys.stdout.write('\r'+s)
    sys.stdout.flush()
    
def dt():
    return datetime.datetime.now().strftime('%H:%M:%S')

def save_model(model,filename):
    state = model.state_dict()
    torch.save(state, filename)

In [4]:
def train(train_loader, model, criterion, optimizer, epoch):
    model.train()
    
    train_loss = 0
    correct = 0
    total = 0
    batch_idx = 0
    
    for i, (inputs, targets, _) in enumerate(train_loader):
        
        optimizer.zero_grad()
        
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)

        inputs = torch.autograd.Variable(inputs)
        targets = torch.autograd.Variable(targets)
        
        inputs = inputs.view((-1,3)+inputs.size()[-2:])
        outputs = model(inputs)
        
        
        loss = criterion(outputs,targets)
        
        loss.backward()
        optimizer.step()
        
        #tsn uses clipping gradient
        if gd is not None:
            total_norm = clip_grad_norm(model.parameters(),gd)
            if total_norm > gd:
                print('clippling gradient: {} with coef {}'.format(total_norm, gd/total_norm))
                
        train_loss += loss.data[0]
        
        if i % print_freq == 0:
            printoneline(dt(),'Epoch=%d Loss=%.4f\n'
                % (epoch,train_loss/(batch_idx+1)))
        batch_idx += 1


def validate(val_loader, model, criterion, epoch):
    model.eval()
    
    err_arou = 0.0
    err_vale = 0.0
    
    print('Loading output file')
    txt_result = open('results/val_lstm_%d.csv'%epoch, 'w')
    txt_result.write('video,utterance,arousal,valence\n')
    for (inputs, targets,(vid, utter)) in tqdm(val_loader, 'Validation batch'):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        
        inputs = torch.autograd.Variable(inputs)
        targets = torch.autograd.Variable(targets)
        
        inputs = inputs.view((-1,3)+inputs.size()[-2:])
        print(f'Getting model output')
        outputs = model(inputs)
        
        outputs = outputs.data.cpu().numpy()
        targets = targets.data.cpu().numpy()
        
        err_arou += np.sum((outputs[:,0]-targets[:,0])**2)
        err_vale += np.sum((outputs[:,1]-targets[:,1])**2)
        
        print(
            'preparing to write ccc results'
        )
        for i in range(len(vid)):
            out = outputs
            txt_result.write('%s,%s.mp4,%f,%f\n'%(vid[i], utter[i],out[i][0],out[i][1]))
    
    txt_result.close()
    
    arouCCC, valeCCC = calculateCCC('./results/omg_ValidationVideos.csv','results/val_lstm_%d.csv'%epoch)
    return (arouCCC,valeCCC)

class OMGDataset(Dataset):
    """OMG dataset."""

    def __init__(self, txt_file, base_path, transform=None):
        self.base_path = base_path
        self.data = pd.read_csv(txt_file, sep=" ", header=0, index_col=0)
        self.data.dropna(inplace=True, how='any')
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        vid = self.data.iloc[idx,0]
        utter = self.data.iloc[idx,1]
        img_list = self.data.iloc[idx,-1]
        img_list = img_list.split(',')[:-1]
        # img_list = [int(img) for img in img_list]
        print('Prepared imaged ids')
        
        num_frames = len(img_list)
        # inspired by TSN's pytorch code
        # FIXME: num_seg is hardcoded
        average_duration = num_frames // num_seg
        if num_frames>num_seg:
            offsets = np.multiply(list(range(num_seg)), average_duration) + randint(average_duration, size=num_seg)
        else:
            tick = num_frames / float(num_seg)
            offsets = np.array([int(tick / 2.0 + tick * x) for x in range(num_seg)])
        print('Created offests')

        final_list = [img_list[i] for i in offsets]
        
        # stack images within a video in the depth dimension
        print('Stacking images')
        for i,ind in enumerate(final_list):
            image = io.imread(join_paths(self.base_path,'%s/%s/%s.png'%(vid,utter,ind))).astype(np.float32)
            image = torch.from_numpy(((image - 127.5)/128).transpose(2,0,1))
            if i==0:
                images = image
            else:
                images = torch.cat((images,image), 0)
        print('Stacked images')
        
        label = torch.from_numpy(np.array([self.data.iloc[idx,2], self.data.iloc[idx,3]]).astype(np.float32))
        print('Prepared labels')

        if self.transform:
            image = self.transform(image)
        return (images, label, (vid,utter))
    
    

model = Net(sphereface)

if use_cuda:
    model.cuda()

criterion = torch.nn.MSELoss()

print('Preparing data loaders')
train_loader = DataLoader(OMGDataset(train_list_path,train_data_path), 
                          batch_size=bs, shuffle=True, num_workers=1)
val_loader = DataLoader(OMGDataset(val_list_path,validation_data_path), 
                        batch_size=bs, shuffle=False, num_workers=1)



optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)    

best_arou_ccc, best_vale_ccc = validate(val_loader, model, criterion,0)

for epoch in tqdm(range(n_epoch), desc='Epoch'):
    if epoch in lr_steps:
        lr *= 0.1
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)    

    train(train_loader, model, criterion, optimizer, epoch)
    
    # evaluate on validation set
    if (epoch+1)%eval_freq == 0 or epoch == n_epoch-1:
        arou_ccc, vale_ccc = validate(val_loader, model, criterion,epoch)
        
        if (arou_ccc+vale_ccc) > (best_arou_ccc + best_vale_ccc):
            best_arou_ccc = arou_ccc
            best_vale_ccc = vale_ccc
            save_model(model,'./pth/model_lstm_{}_{}_{}.pth'.format(epoch, round(arou_ccc,4), round(vale_ccc,4)))
            

  torch.nn.init.orthogonal(param)
  torch.nn.init.constant(param, 0.0)


Preparing data loaders
Loading output file


Validation batch:   0%|          | 0/20 [00:00<?, ?it/s]

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/leonardoalchieri/miniconda3/envs/torch_latest/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Users/leonardoalchieri/miniconda3/envs/torch_latest/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'OMGDataset' on <module '__main__' (built-in)>


KeyboardInterrupt: 