In [1]:
import os
# Move the working directory one level up so the relative paths used below
# ("weights/...", "params/...", "trained_models/...") are resolved from there.
# NOTE(review): presumably the parent directory is the project root that holds
# the `models` and `data` packages - confirm.
# NOTE(review): this relative chdir is not idempotent; re-running the cell
# without restarting the kernel moves up one more directory each time.
os.chdir("../")

In [2]:
import pickle
import random
import time
from typing import Any

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch import optim
from torch.utils.data import DataLoader, Subset
from torchsummary import summary
from torchvision import transforms
from tqdm.notebook import tqdm

from data.affwild2_dataset import AffWild2VADataset
from models.blazeface import BlazeFace
from models.mobilenetv2 import mobilenetv2

In [3]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Utility Functions

In [4]:
def random_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy and PyTorch with ``seed``. When CUDA is
    available the CUDA generators are seeded as well and cuDNN is switched to
    deterministic, non-benchmarking mode. Finally ``PYTHONHASHSEED`` is
    exported with the same value.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        # Seed the current device and then every visible device.
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # Prefer reproducible kernels over autotuned ones.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    os.environ["PYTHONHASHSEED"] = str(seed)

def store_params(content, name):
    """Pickle ``content`` to ``params/{name}.pkl``.

    Args:
        content: any picklable object.
        name: file stem; the object is written to ``params/{name}.pkl``
            relative to the current working directory.

    Raises:
        OSError: if the file cannot be opened for writing (for example when
            the ``params`` directory does not exist).
        pickle.PicklingError: if ``content`` cannot be pickled.
    """
    # The context manager closes the handle even when pickle.dump raises.
    with open(f'params/{name}.pkl', 'wb') as f:
        pickle.dump(content, f)

def load_params(name):
    """Load and return the object pickled at ``params/{name}.pkl``.

    Args:
        name: file stem; the object is read from ``params/{name}.pkl``
            relative to the current working directory.

    Returns:
        The unpickled object.

    Raises:
        OSError: if the file cannot be opened.

    NOTE: unpickling can execute arbitrary code, so only load files that were
    produced by a trusted source.
    """
    # The context manager closes the handle, which the manual open() never did.
    with open(f'params/{name}.pkl', "rb") as fl:
        return pickle.load(fl)

def store_model(model, name):
    """Save the model's ``state_dict`` to ``./trained_models/{name}.pth``."""
    weights = model.state_dict()
    destination = f'./trained_models/{name}.pth'
    torch.save(weights, destination)
                                

# Data Preparation

In [5]:
# Run-wide settings. Order matters: the RNGs are seeded here, before any
# dataset, subset or loader is created further down.
# project_name: wandb project passed to wandb.init in train_model.
project_name = 'moody_much'
# cores: num_workers for the full training DataLoader built in train_model.
cores = 12
random_seed(8)
# batch_size: samples per batch for the loaders built in the cells below.
batch_size = 1

In [6]:
# Training and validation views of the AffWild2 dataset; both use the same
# `skip` and `split` arguments and differ only in the `train` flag.
# NOTE(review): presumably `skip` is a frame stride and `split` the training
# fraction - confirm against data.affwild2_dataset.
train_dataset = AffWild2VADataset(train=True, skip=4, split=0.8)
valid_dataset = AffWild2VADataset(train=False, skip=4, split=0.8)

In [7]:
# Display the number of items in the training and validation datasets.
len(train_dataset), len(valid_dataset)

(201, 51)

In [8]:
# Dataset sizes used to carve out the splits below.
total_valid_num = len(valid_dataset)
total_train_num = len(train_dataset)
valid_num = total_valid_num // 2

# First half of valid_dataset is used for validation, second half for testing.
valid_mask = list(range(0, valid_num))
test_mask = list(range(valid_num, total_valid_num))

valid_loader = DataLoader(
    Subset(valid_dataset, valid_mask),
    batch_size=batch_size,
    shuffle=True,
)
test_loader = DataLoader(
    Subset(valid_dataset, test_mask),
    batch_size=batch_size,
    shuffle=True,
)

# Random index subsets for quick experiments. The three draws stay in this
# order so the seeded Python RNG yields the same indices as before.
small_train_mask = random.sample(range(total_train_num), 20)
medium_train_mask = random.sample(range(total_train_num), 50)
small_valid_mask = random.sample(range(total_valid_num), 2)

small_train_loader = DataLoader(
    Subset(train_dataset, list(small_train_mask)),
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)
small_valid_loader = DataLoader(
    Subset(valid_dataset, list(small_valid_mask)),
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

medium_loader = DataLoader(
    Subset(train_dataset, list(medium_train_mask)),
    batch_size=batch_size,
    shuffle=True,
)

# Training

In [9]:
# Run configuration handed to wandb.init(config=...) inside train_model.
# train_model reads `epochs`, `batch_size`, `learning_rate` and `weight_decay`
# from it.
# NOTE(review): `pretrained` and `optimizer` are recorded in the run config
# but are not read anywhere in this notebook: pretrained weights are always
# loaded and Adam is always used.
hyperparameters = {
    'epochs': 20,
    'pretrained': True,
    'batch_size': 1,
    'learning_rate': 0.001,
    'optimizer': 'adam',
    'weight_decay': 4e-5,
}

In [10]:
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
# Face detector used by get_face_frames; its weights and anchors are read from
# files under weights/ relative to the current working directory.
blazeface = BlazeFace().to(device)
blazeface.load_weights("weights/blazeface.pth")
blazeface.load_anchors("weights/anchors.npy")


def get_face_frames(frames):
    """Crop the detected face out of every frame and resize it to 128x128.

    Args:
        frames: frames of one clip, laid out as (frame_num, 128, 128, 3)
            according to the original inline notes.

    Returns:
        A float32 tensor laid out as (frame_num, 3, 128, 128).
    """
    side = 128
    frame_total = frames.shape[0]
    resized = np.zeros((frame_total, side, side, 3))

    # BlazeFace receives the frames channels-first: (frame_num, 3, 128, 128).
    channels_first = frames.permute(0, 3, 1, 2)
    per_frame_detections = blazeface.predict_on_batch(channels_first)

    # One detection entry per frame; get_face turns it into a cropped array.
    for idx, detection in enumerate(per_frame_detections):
        crop = get_face(frames[idx], detection)
        resized[idx] = cv2.resize(crop, (side, side))

    as_float32 = resized.astype(np.float32)
    return torch.from_numpy(as_float32).permute(0, 3, 1, 2)
    
def get_face(img, detections):
    """Crop the first detected face out of ``img``.

    Args:
        img: image tensor indexed as (height, width, channels).
        detections: detections for this image, either a single row or a
            2-D (num_detections, values) array/tensor. The first four values
            of a row are read as normalised ``ymin, xmin, ymax, xmax``.

    Returns:
        A NumPy array with the cropped region. The full image is returned
        when there is no detection (after printing "No Face Detected") or
        when the box collapses to an empty region.
    """
    if isinstance(detections, torch.Tensor):
        detections = detections.cpu().numpy()
    detections = np.asarray(detections)

    if detections.ndim == 1:
        detections = np.expand_dims(detections, axis=0)

    if detections.shape[0] == 0:
        print("No Face Detected")
        return img.cpu().numpy()

    height, width = img.shape[0], img.shape[1]
    box = detections[0]

    # Clamp every edge to the image so a box that pokes outside the frame is
    # trimmed instead of producing negative or oversized slice bounds.
    top = int(min(max(box[0] * height, 0), height))
    left = int(min(max(box[1] * width, 0), width))
    bottom = int(min(max(box[2] * height, 0), height))
    right = int(min(max(box[3] * width, 0), width))

    # A degenerate box would yield an empty crop, which cannot be resized by
    # the caller; fall back to the whole frame in that case.
    if bottom <= top or right <= left:
        return img.cpu().numpy()

    return img[top:bottom, left:right, :].cpu().numpy()

cuda:0


In [11]:
def evaluate(model: nn.Module, data_loader: Any, device: torch.device, comment: str = ""):
    """Compute the mean arousal + valence MSE of ``model`` over ``data_loader``.

    Every item of the loader is unpacked as ``(frames, arousal, valence)``.
    The frames are cropped with ``get_face_frames`` and fed to the model in
    chunks of at most 128 frames; the loss of a chunk is the sum of the
    arousal MSE (output column 0) and the valence MSE (output column 1).

    Args:
        model: network returning one row per frame with at least two columns.
        data_loader: loader yielding ``(frames, arousal, valence)``.
            Assumes batch_size=1, i.e. frames is (1, num_frames, H, W, 3) -
            TODO confirm against the dataset.
        device: device the chunks are moved to before the forward pass.
        comment: optional split label. When given, the result is logged to
            wandb as ``f"{comment}_loss"``; when empty it is logged as
            ``"valid_loss"``.

    Returns:
        The mean of the per-chunk losses, or ``nan`` when the loader produced
        no frames at all.
    """
    model.eval()

    batch = 128
    total_loss = 0.0
    count = 0
    log_key = f"{comment}_loss" if comment else "valid_loss"

    with torch.no_grad():
        for (frames, arousal, valence) in tqdm(data_loader):

            num_frames = frames.shape[1]
            # squeeze(0) removes only the loader's batch axis; a bare squeeze()
            # would also drop the frame axis of a one-frame clip.
            face_frames = get_face_frames(frames.squeeze(0))

            # Step through the clip in chunks and keep the trailing partial
            # chunk, so clips shorter than `batch` are evaluated too.
            for l in range(0, num_frames, batch):
                r = min(l + batch, num_frames)

                face_batch = face_frames[l:r].to(device)
                a = arousal[:, l:r].to(device)
                v = valence[:, l:r].to(device)

                out = model(face_batch)

                a_pred = out[:, 0]
                v_pred = out[:, 1]

                # Flatten both sides so a one-frame chunk still compares
                # tensors of identical shape.
                a_loss = F.mse_loss(a_pred.reshape(-1), a.reshape(-1))
                v_loss = F.mse_loss(v_pred.reshape(-1), v.reshape(-1))
                loss = a_loss + v_loss

                total_loss += loss.item()
                count += 1

    # Guard against an empty loader instead of dividing by zero.
    avg_loss = total_loss / count if count else float("nan")
    wandb.log({log_key: avg_loss})

    return avg_loss

In [12]:
def train(model, optimizer, epochs, data_loader, test_loader, device):
    """Train ``model`` for ``epochs`` epochs and log progress to wandb.

    Every item of ``data_loader`` is unpacked as ``(frames, arousal,
    valence)``. The frames are cropped with ``get_face_frames`` and processed
    in chunks of at most 128 frames, with one optimizer step per chunk. The
    loss is the sum of the arousal MSE (output column 0) and the valence MSE
    (output column 1). After each epoch the model is evaluated on
    ``test_loader`` and ``data_loader``, and a checkpoint is written on every
    third epoch (0, 3, 6, ...).

    Args:
        model: network returning one row per frame with at least two columns.
        optimizer: optimizer bound to the model's parameters.
        epochs: number of passes over ``data_loader``.
        data_loader: training loader. Assumes batch_size=1, i.e. frames is
            (1, num_frames, H, W, 3) - TODO confirm against the dataset.
        test_loader: loader evaluated after every epoch as validation data.
        device: device used for the forward and backward passes.
    """
    batch = 128
    checkpoint_every = 3

    # Moving the model once is enough; nothing in the loop moves it back.
    model.to(device)

    full_start = time.time()
    for epoch in tqdm(range(epochs)):

        model.train()
        print(f"Starting Epoch {epoch}")

        epoch_time = time.time()

        for (frames, arousal, valence) in tqdm(data_loader):

            num_frames = frames.shape[1]
            # squeeze(0) removes only the loader's batch axis; a bare squeeze()
            # would also drop the frame axis of a one-frame clip.
            face_frames = get_face_frames(frames.squeeze(0))

            # Keep the trailing partial chunk so no frames are silently
            # dropped and clips shorter than `batch` still contribute.
            for l in range(0, num_frames, batch):
                r = min(l + batch, num_frames)

                face_batch = face_frames[l:r].to(device)
                a = arousal[:, l:r].to(device)
                v = valence[:, l:r].to(device)

                optimizer.zero_grad()

                out = model(face_batch)

                a_pred = out[:, 0]
                v_pred = out[:, 1]

                # Flatten both sides so a one-frame chunk still compares
                # tensors of identical shape.
                a_loss = F.mse_loss(a_pred.reshape(-1), a.reshape(-1))
                v_loss = F.mse_loss(v_pred.reshape(-1), v.reshape(-1))
                loss = a_loss + v_loss

                loss.backward()
                optimizer.step()

                wandb.log({'batch_loss': loss.item()})

        print(f"Finished Epoch {epoch}")
        valid_loss = evaluate(model, test_loader, device)
        # Label this pass as the training split through evaluate's free-form
        # `comment` argument.
        train_loss = evaluate(model, data_loader, device, comment="train")

        wandb.log({
            'train_loss': train_loss,
            'valid_loss': valid_loss,
            'epoch_time_minutes': (time.time() - epoch_time) / 60
        })

        # Checkpoint by epoch number; the previous condition used the inner
        # chunk index, which is unrelated to the epoch and may be undefined.
        if epoch % checkpoint_every == 0:
            torch.save(model.state_dict(), 'trained_models/moody_much_checkpoint.pth')

    wandb.log({'full_run_time_minutes': (time.time() - full_start) / 60})
        

In [13]:
def train_model(hyperparameters, model=None, model_path=None):
    """Run one wandb-tracked training session and return the trained model.

    Args:
        hyperparameters: dict passed to ``wandb.init(config=...)``. The keys
            ``learning_rate``, ``weight_decay``, ``batch_size`` and ``epochs``
            are read from the resulting config.
        model: model to train. When ``None`` a MobileNetV2 is created, loaded
            with the weights in ``weights/mobilenetv2_128x128-fd66a69d.pth``
            and given a fresh two-output linear classifier.
        model_path: currently unused.

    Returns:
        ``(model, test_loss)``. ``test_loss`` is the constant 0 because the
        test evaluation is disabled.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    with wandb.init(project=project_name, config=hyperparameters):

        config = wandb.config

        if model is None:
            print("Creating a New Model.")
            model = mobilenetv2()
            # map_location keeps the load working on CPU-only machines even if
            # the checkpoint was saved from a CUDA device; train() moves the
            # model to the target device afterwards.
            pretrained_weights = torch.load(
                'weights/mobilenetv2_128x128-fd66a69d.pth', map_location='cpu'
            )
            model.load_state_dict(pretrained_weights)
            model.classifier = nn.Linear(model.classifier.in_features, 2)

        optimizer = optim.Adam(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)

        # NOTE(review): this full-dataset loader is built but not used below;
        # the run trains and validates on the small subset loaders instead.
        train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=cores)

        train(model, optimizer, config.epochs, small_train_loader, small_valid_loader, device)

        # NOTE(review): the test evaluation is disabled, so the value logged
        # as test_loss is a placeholder rather than a measurement.
        test_loss = 0  # evaluate(model, test_loader, device)

        wandb.log({'test_loss': test_loss})

    return model, test_loss

In [None]:
# Start a training run with the hyperparameters defined above and keep the
# returned model and test_loss for the cells that follow.
model, test_loss = train_model(hyperparameters)

cuda:0


[34m[1mwandb[0m: Currently logged in as: [33mnazirnayal98[0m (use `wandb login --relogin` to force relogin)


Creating a New Model.


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))

Starting Epoch 0


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))

No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detect

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))

No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detect

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))

No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detect

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))

No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detect

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))

No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detect

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))

No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detect

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))

No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detected
No Face Detect

In [None]:
# Summarise the trained network on whichever device was selected earlier
# instead of assuming CUDA, and use the 128x128 input resolution that the
# face crops and the pretrained weights in this notebook are built around.
summary(model.to(device), (3, 128, 128), device=device.type)