In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Subset

import os

import dataloader
import MultiModal_Model as mm_models

from tqdm.autonotebook import tqdm
from notify_run import Notify



## Trainer Class (checkpoint evaluation on the test set)

In [2]:
class ModelTrainer:
    """Evaluate saved checkpoints of a video + audio model on a test set.

    Despite the name, this class only runs evaluation. For a given epoch it loads
    ``experiments/<run_id>/trainer-epoch-<epoch>.pkl``, computes loss and accuracy
    over ``test_loader``, appends a line to ``experiments/<run_id>/log.txt`` and
    pushes the same line through ``notify``. Per-epoch results accumulate in
    ``test_losses`` and ``test_acc``.
    """

    def __init__(self, model, test_loader, critereon, device, notify, run_id, num_epochs):
        """
        Args:
            model: module called as ``model(videos, audios)``; the second element of
                the returned tuple is used as the linear output (logits).
            test_loader: iterable yielding ``(videos, audios, _, _, labels)`` batches.
            critereon: loss applied as ``critereon(logits, labels)`` (attribute keeps
                the original spelling so existing callers keep working).
            device: torch device the model and every batch are moved to.
            notify: object exposing ``send(str)``, used to push each test result.
            run_id: experiment directory name under ``experiments/``.
            num_epochs: number of trained epochs; stored for reference only.
        """
        self.model = model
        self.test_loader = test_loader
        self.critereon = critereon
        self.device = device
        self.notify = notify
        self.run_id = run_id
        self.num_epochs = num_epochs  # previously accepted but discarded

        self.log_file_path = os.path.join('experiments', self.run_id, 'log.txt')

        # One entry per call to test_epoch, in call order.
        self.test_losses = []
        self.test_acc = []

    def _test_batch(self, videos, audios, labels):
        """Evaluate one batch with gradient tracking disabled.

        Returns:
            batch_loss: scalar loss for this batch as a Python float (a per-sample
                mean when ``critereon`` uses the default ``reduction='mean'``).
            batch_corrects: number of correct argmax predictions in this batch.
            batch_size: number of samples in this batch.
        """
        videos = videos.to(self.device)
        audios = audios.to(self.device)
        labels = labels.to(self.device)

        # Test time needs no autograd graph; building one only wastes memory.
        with torch.no_grad():
            _, preds = self.model(videos, audios)  # linear output (logits)
            batch_loss = self.critereon(preds, labels).item()
            _, max_preds = torch.max(preds, 1)
            batch_corrects = (max_preds == labels).sum().item()

        batch_size = labels.shape[0]
        return batch_loss, batch_corrects, batch_size

    def _test_notify_and_log(self, epoch_loss, epoch_acc):
        """Print the test result, append it to the log file and send it via ``notify``.

        Reads ``self.curr_epoch``, which ``test_epoch`` sets before calling this.
        """
        test_str = '[TEST]  Epoch {} Loss: {:.4f} Acc: {:.2f}%'.format(self.curr_epoch, epoch_loss, epoch_acc)
        print(test_str)
        with open(self.log_file_path, 'a+') as f:
            f.write(test_str+"\n")
        self.notify.send("{} {}".format(self.run_id, test_str))

    def test_epoch(self, curr_epoch):
        """Load the checkpoint saved for ``curr_epoch`` and evaluate it on the test set.

        Appends the mean per-sample loss to ``test_losses`` and the accuracy (in
        percent) to ``test_acc``, then prints/logs/notifies the result.

        Raises:
            ValueError: if ``test_loader`` yields no samples.
        """
        self.curr_epoch = curr_epoch
        checkpoint_path = os.path.join(
            "experiments", self.run_id, "trainer-epoch-{}.pkl".format(curr_epoch))
        # map_location lets a checkpoint written on a GPU be evaluated on a CPU-only host.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        trainer_state = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(trainer_state["model_state_dict"])

        self.model.eval()
        self.model.to(self.device)

        loss_sum = 0.0
        epoch_corrects = 0
        num_samples = 0

        for videos, audios, _, _, labels in tqdm(self.test_loader):
            batch_loss, batch_corrects, batch_size = self._test_batch(videos, audios, labels)
            # Weight each batch mean by its size so a short final batch does not skew
            # the epoch loss (assumes critereon uses reduction='mean').
            loss_sum += batch_loss * batch_size
            epoch_corrects += batch_corrects
            num_samples += batch_size

        if num_samples == 0:
            raise ValueError("test_loader yielded no samples; cannot compute test metrics")

        epoch_loss = loss_sum / num_samples
        epoch_acc = epoch_corrects / num_samples * 100.0

        self._test_notify_and_log(epoch_loss, epoch_acc)
        self.test_losses.append(epoch_loss)
        self.test_acc.append(epoch_acc)

## Load Data & Create Dataloader
Load the test split for single-frame prediction.

In [3]:
# Test split in single-frame mode, plus the audio dataset.
# NOTE(review): AudioDataset() receives no split/path argument -- confirm it
# serves audio for the same test samples as TEST_JSON_PATH.
test_video_dataset = dataloader.get_dataset(
    dataloader.TEST_JSON_PATH,
    dataloader.SINGLE_FRAME,
)
test_audio_dataset = dataloader.AudioDataset()

loaded 147516 images 


In [4]:
BATCH_SIZE = 64
test_loader = dataloader.AVDataLoader(
    test_video_dataset,
    test_audio_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    single_frame=True,
)

# Sanity-check one batch before the full evaluation run: show the shapes and fail
# loudly if the loader is empty or the modalities disagree on the batch size.
print("Verifying data sizes")
first_batch = next(iter(test_loader), None)
assert first_batch is not None, "test_loader yielded no batches"
sample_videos, sample_audios, _, _, sample_labels = first_batch
print('videos shape:', sample_videos.shape)  # batch_size*3(channel)*224*224
print('audios shape:', sample_audios.shape)  # batch_size*5*50(channel)
print('labels shape:', sample_labels.shape)  # batch_size
assert sample_videos.shape[0] == sample_audios.shape[0] == sample_labels.shape[0] <= BATCH_SIZE, \
    "videos, audios and labels must share a batch dimension no larger than BATCH_SIZE"

Verifying data sizes
videos shape: torch.Size([64, 3, 224, 224])
audios shape: torch.Size([64, 5, 50])
labels shape: torch.Size([64])


## Define Model and Evaluation Parameters

In [5]:
run_id = "MultiModalSimpleConcatDeepfakeInitFreezeModel"  # or "test" for a scratch run
num_epochs = 2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class _NoopNotify:
    """Stand-in notifier used when no notify.run endpoint is configured."""

    def send(self, message):
        pass


# A notify.run endpoint URL lets anyone holding it post to the channel, so treat it
# as a secret: read it from the environment rather than committing it to the notebook.
_notify_endpoint = os.environ.get("NOTIFY_RUN_ENDPOINT")
if _notify_endpoint:
    notify = Notify(endpoint=_notify_endpoint)
else:
    print("NOTIFY_RUN_ENDPOINT not set; push notifications disabled")
    notify = _NoopNotify()

model = mm_models.MultiModalSimpleConcatModel()
critereon = torch.nn.CrossEntropyLoss()

tester = ModelTrainer(model, test_loader, critereon, device, notify, run_id, num_epochs)

In [6]:
# Evaluate checkpoint experiments/<run_id>/trainer-epoch-<curr_epoch>.pkl.
curr_epoch = 2
tester.test_epoch(curr_epoch=curr_epoch)

HBox(children=(IntProgress(value=0, max=2305), HTML(value='')))


[TEST]  Epoch 2 Loss: 0.1907 Acc: 95.48%
