In [1]:
import os
# Move the working directory one level up so that the relative paths and
# package imports used below (models/, data/, weights/, params/,
# trained_models/) resolve from there.
# NOTE: the path is relative, so re-running this cell without restarting the
# kernel moves up one more directory each time.
os.chdir("../")

In [2]:
import pickle
import random
import time
from typing import Any

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch import optim
from torch.utils.data import DataLoader, Subset
from torchsummary import summary
from torchvision import transforms
from tqdm.notebook import tqdm

from data.affwild2_dataset import AffWild2ExprDataset
from models.mobilenetv2 import mobilenetv2

In [11]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Utility Functions

In [12]:
def random_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch.
    When CUDA is available, the CUDA generators are seeded as well and cuDNN
    is switched to deterministic mode with benchmarking disabled.

    Args:
        seed: Value handed unchanged to each seeding function.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Python reads PYTHONHASHSEED at interpreter start-up, so this only
    # influences processes launched after this point.
    os.environ.update({"PYTHONHASHSEED": str(seed)})

def store_params(content, name):
    """Pickle ``content`` to ``params/<name>.pkl``.

    The path is relative to the current working directory. The target
    directory is created when it does not exist yet, and the file is written
    inside a ``with`` block so the handle is closed even if pickling fails.

    Args:
        content: Any picklable object.
        name: File name without the ``.pkl`` extension.
    """
    path = f'params/{name}.pkl'
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, 'wb') as f:
        pickle.dump(content, f)

def load_params(name):
    """Load and return the object pickled at ``params/<name>.pkl``.

    The path is relative to the current working directory. The file is read
    inside a ``with`` block so the handle is always closed.

    Args:
        name: File name without the ``.pkl`` extension.

    Returns:
        The unpickled object.

    Raises:
        FileNotFoundError: If the file does not exist.
    """
    # SECURITY: unpickling can execute arbitrary code; only load files that
    # were written by this project.
    with open(f'params/{name}.pkl', 'rb') as fl:
        return pickle.load(fl)

def store_model(model, name):
    """Save ``model``'s ``state_dict`` to ``./trained_models/<name>.pth``.

    The path is relative to the current working directory; the target
    directory is created when it does not exist yet.

    Args:
        model: Object exposing ``state_dict()`` (e.g. an ``nn.Module``).
        name: File name without the ``.pth`` extension.
    """
    path = f'./trained_models/{name}.pth'
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(model.state_dict(), path)
                                

# Data Preparation

In [13]:
# wandb project name; passed to wandb.init() in train_model.
project_name = 'moody_much'
# Worker-process count used for the full train DataLoader built in train_model.
cores = 15
# Seed the RNGs before any sampling happens in the cells below.
random_seed(8)
# Batch size for the DataLoaders built in the data-split cell.
batch_size = 1

In [14]:
# Build the two dataset objects; they differ only in the `train` flag.
# NOTE(review): the meaning of `skip` and `remove_mismatch` is defined in
# data.affwild2_dataset and is not visible here; confirm before changing them.
train_dataset = AffWild2ExprDataset(train=True, skip=2, remove_mismatch=True)
valid_dataset = AffWild2ExprDataset(train=False, skip=2, remove_mismatch=True)

In [15]:
# Number of items in each dataset, shown as (train, valid).
len(train_dataset), len(valid_dataset)

(111, 26)

In [10]:
# Split valid_dataset by index: the first half is used for validation, the
# second half for testing.
total_valid_num = len(valid_dataset)
total_train_num = len(train_dataset)
valid_num = int(0.5 * total_valid_num)

valid_mask = list(range(valid_num))
test_mask = list(range(valid_num, total_valid_num))

valid_loader = DataLoader(Subset(valid_dataset, valid_mask), batch_size=batch_size, shuffle=True)
test_loader = DataLoader(Subset(valid_dataset, test_mask), batch_size=batch_size, shuffle=True)

# Random subsets for quick experiments. The draws consume the global `random`
# state, so their order (and the seed set earlier) determines which indices
# are picked.
# NOTE: small_valid_mask is sampled from the whole of valid_dataset, so it can
# contain indices that also belong to test_mask.
small_train_mask = random.sample(range(total_train_num), 20)
medium_train_mask = random.sample(range(total_train_num), 50)
small_valid_mask = random.sample(range(total_valid_num), 2)

# These two loaders are the ones train_model passes to train().
small_train_loader = DataLoader(Subset(train_dataset, list(small_train_mask)), batch_size=batch_size, 
                                shuffle=True, num_workers=2)
small_valid_loader = DataLoader(Subset(valid_dataset, list(small_valid_mask)), batch_size=batch_size, 
                                shuffle=True, num_workers=2)

medium_loader = DataLoader(Subset(train_dataset, list(medium_train_mask)), batch_size=batch_size, shuffle=True)

# Training

In [9]:
# Run configuration handed to wandb.init(config=...) by train_model and read
# back through wandb.config.
hyperparameters = {
    'epochs': 20,            # number of passes over the training loader
    'pretrained': True,      # NOTE(review): not read by the code visible in this notebook
    'batch_size': 1,         # used for the full train DataLoader built in train_model
    'learning_rate': 0.001,  # Adam learning rate
    'optimizer': 'adam',     # NOTE(review): not read; train_model always builds optim.Adam
    'weight_decay': 4e-5,    # Adam weight decay
}

In [16]:
def evaluate(model: nn.Module, data_loader: Any, device: torch.device, comment: str = "") -> float:
    """Measure frame-level classification accuracy of ``model`` on ``data_loader``.

    Every item yielded by ``data_loader`` is unpacked as ``(frames, labels)``.
    ``frames.shape[1]`` is taken as the number of frames and ``labels`` is
    sliced along its second axis, so the loader presumably yields one video
    per item (loader batch size 1) — TODO confirm against the dataset class.
    Frames are fed to the model in chunks of at most 128, including the final
    partial chunk.

    Side effect: logs ``valid_loss`` (sum-reduced NLL averaged over the
    processed chunks) to the active wandb run. The same key is used whichever
    split ``data_loader`` holds.

    Args:
        model: Network mapping a batch of frames to class logits.
        data_loader: Iterable of ``(frames, labels)`` pairs.
        device: Device on which the model is evaluated.
        comment: Currently unused.

    Returns:
        Accuracy in percent as a float; ``0.0`` (and nothing is logged) when
        the loader yields no frames.
    """
    model.eval()
    model.to(device)

    chunk_size = 128
    total_samples = 0
    correct_samples = 0
    total_loss = 0.0
    count = 0

    with torch.no_grad():
        for (frames, labels) in tqdm(data_loader):

            num_frames = frames.shape[1]
            # Remove only the leading loader-batch axis; a bare squeeze()
            # would also collapse any other axis that happens to have size 1.
            face_frames = frames.squeeze(0)

            # Step by chunk_size so the trailing partial chunk is evaluated too.
            for start in range(0, num_frames, chunk_size):
                stop = min(start + chunk_size, num_frames)

                face_batch = face_frames[start:stop].to(device)
                # reshape(-1) keeps a 1-D target even for a single-frame chunk.
                label = labels[:, start:stop].reshape(-1).to(device)

                out = model(face_batch)
                out_prob = F.log_softmax(out, dim=1)
                loss = F.nll_loss(out_prob, label, reduction='sum')
                pred = out_prob.argmax(dim=1)

                total_loss += loss.item()
                correct_samples += pred.eq(label).sum().item()
                total_samples += label.shape[0]
                count += 1

    if count == 0 or total_samples == 0:
        # Nothing was evaluated; avoid dividing by zero.
        return 0.0

    avg_loss = total_loss / count
    wandb.log({'valid_loss': avg_loss})
    return 100.0 * correct_samples / total_samples

In [12]:
def train(model, optimizer, epochs, data_loader, test_loader, device):
    """Train ``model`` for ``epochs`` epochs and log progress to wandb.

    Every item yielded by ``data_loader`` is unpacked as ``(frames, labels)``.
    ``frames.shape[1]`` is taken as the number of frames, so the loader
    presumably yields one video per item (loader batch size 1) — TODO confirm
    against the dataset class. Frames are processed in chunks of at most 128,
    including the final partial chunk, with one optimizer step per chunk.

    After each epoch the model is evaluated with ``evaluate`` on
    ``test_loader`` and on ``data_loader``. Note that each ``evaluate`` call
    also logs a ``valid_loss`` entry.

    Logged to the active wandb run:
        * ``batch_loss``: sum-reduced NLL of each chunk.
        * ``loss``: mean of ``batch_loss`` over the epoch (NaN if the epoch
          processed no chunk).
        * ``valid_accuracy``, ``train_accuracy``, ``epoch_time_minutes``.
        * ``full_run_time_minutes`` once at the end.

    Args:
        model: Network mapping a batch of frames to class logits.
        optimizer: Optimizer bound to ``model``'s parameters.
        epochs: Number of passes over ``data_loader``.
        data_loader: Training iterable of ``(frames, labels)`` pairs.
        test_loader: Held-out iterable used for ``valid_accuracy``.
        device: Device on which training runs.
    """
    # wandb.watch(model, log="all", log_freq=10)

    chunk_size = 128
    full_start = time.time()
    model.to(device)

    for epoch in tqdm(range(epochs)):

        # evaluate() leaves the model in eval mode, so re-enable train mode
        # at the start of every epoch.
        model.train()
        print(f"Starting Epoch {epoch}")

        total_loss = 0.0
        num_batches = 0
        epoch_time = time.time()

        for (frames, labels) in tqdm(data_loader):

            num_frames = frames.shape[1]
            # Remove only the leading loader-batch axis; a bare squeeze()
            # would also collapse any other axis that happens to have size 1.
            face_frames = frames.squeeze(0)

            # Step by chunk_size so the trailing partial chunk is trained on too.
            for start in range(0, num_frames, chunk_size):
                stop = min(start + chunk_size, num_frames)

                face_batch = face_frames[start:stop].to(device)
                # reshape(-1) keeps a 1-D target even for a single-frame chunk.
                label = labels[:, start:stop].reshape(-1).to(device)

                optimizer.zero_grad()

                out = model(face_batch)
                out_prob = F.log_softmax(out, dim=1)
                loss = F.nll_loss(out_prob, label, reduction='sum')

                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                total_loss += batch_loss
                num_batches += 1

                wandb.log({'batch_loss': batch_loss})

        print(f"Finished Epoch {epoch}")
        valid_accuracy = evaluate(model, test_loader, device)
        train_accuracy = evaluate(model, data_loader, device)

        print(f"Validation Accuracy: ", valid_accuracy)
        print(f"Training Accuracy: ", train_accuracy)

        wandb.log({
            'loss': total_loss / num_batches if num_batches else float('nan'),
            'valid_accuracy': valid_accuracy,
            'train_accuracy': train_accuracy,
            'epoch_time_minutes': (time.time() - epoch_time) / 60
        })

    wandb.log({'full_run_time_minutes': (time.time() - full_start) / 60})
        

In [13]:
def train_model(hyperparameters, model=None, model_path=None):
    """Run one wandb-tracked training session and return the trained model.

    Opens a wandb run in project ``project_name`` with ``hyperparameters`` as
    its config, builds a model when none is supplied, trains it with Adam and
    finally evaluates it on the module-level ``test_loader``.

    Args:
        hyperparameters: Mapping read back through ``wandb.config``; the keys
            ``learning_rate``, ``weight_decay``, ``batch_size`` and ``epochs``
            are used here.
        model: Optional model to continue training. When ``None``, a
            ``mobilenetv2()`` is created, the checkpoint under ``weights/`` is
            loaded into it and its classifier is replaced by a 7-way linear
            layer.
        model_path: Currently unused.

    Returns:
        Tuple ``(model, test_accuracy)`` where ``test_accuracy`` is the value
        returned by ``evaluate`` on ``test_loader``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    with wandb.init(project=project_name, config=hyperparameters):

        config = wandb.config

        if model is None:
            print("Creating a New Model.")
            model = mobilenetv2()
            # map_location='cpu' lets the checkpoint load on machines without
            # CUDA; train() moves the model to `device` afterwards.
            state_dict = torch.load('weights/mobilenetv2_128x128-fd66a69d.pth', map_location='cpu')
            model.load_state_dict(state_dict)
            # Swap the classification head for a 7-class output layer.
            model.classifier = nn.Linear(model.classifier.in_features, 7)

        optimizer = optim.Adam(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)

        # NOTE(review): this full-dataset loader is built but not used; the
        # train() call below runs on the small debugging loaders. Confirm
        # which loaders are intended before a real run.
        train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=cores)

        train(model, optimizer, config.epochs, small_train_loader, small_valid_loader, device)

        test_accuracy = evaluate(model, test_loader, device)

        wandb.log({'test_accuracy': test_accuracy})

    # The original returned the undefined name `test_loss`; the value computed
    # here is the test accuracy.
    return model, test_accuracy

In [None]:
# Launch a training run with the hyperparameters defined in the Training
# section and keep the returned model for inspection in later cells.
model, test_loss = train_model(hyperparameters)

cuda:0


[34m[1mwandb[0m: Currently logged in as: [33mnazirnayal98[0m (use `wandb login --relogin` to force relogin)


Creating a New Model.


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))

Starting Epoch 0


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))

No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detect

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))

No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detect

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))

No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detect

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))

No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detect

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))

No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detect

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))

No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detect

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))

No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detect

In [None]:
# Print a layer-by-layer summary. The device is chosen at run time so the
# cell also works on machines without CUDA, matching the CPU fallback used in
# train_model.
# NOTE(review): the checkpoint file name loaded in train_model mentions
# 128x128; confirm that (3, 224, 224) is the intended input size here.
summary_device = "cuda" if torch.cuda.is_available() else "cpu"
summary(model.to(summary_device), (3, 224, 224), device=summary_device)