In [1]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from notify_run import Notify
from torch.utils.data import Subset
from tqdm.autonotebook import tqdm

import dataloader
# import multimodal_ae_models as mm_ae_models
import multimodal_ae_simple_models as mm_ae_simple_models



## Trainer Class

In [2]:
class ModelTrainer:
    """Train/validate a two-stream (video + audio) model and checkpoint it per epoch.

    Per-epoch metrics accumulate in ``train_losses``, ``val_losses`` and
    ``val_acc``. Logs and checkpoints are written under ``experiments/<run_id>/``,
    and every epoch summary is also pushed through ``notify.send``.

    The model is called as ``model(videos, audios)`` and must return the 6-tuple
    ``(video_out, audio_out, binary_out, classification_embedding, video_embed, audio_embed)``.
    It must also expose ``model.loss(outputs, targets)``. Losses come from
    ``model.loss``; ``critereon`` is stored but not used by this class.
    """

    def __init__(self, model, train_loader, val_loader,
                 optimizer, critereon, scheduler,
                 device,
                 notify, run_id = "test"):
        """
        Args:
            model: torch module with the call/loss contract described on the class.
            train_loader, val_loader: iterables yielding
                ``(videos, audios, _, _, labels)`` batches. The two ignored items
                are lengths, which are unused for single-frame input.
            optimizer: optimizer over ``model``'s parameters.
            critereon: kept for reference only; not used by this class.
            scheduler: stepped once per validation epoch with the mean val loss.
            device: device the model and batches are moved to.
            notify: object with a ``send(str)`` method for progress messages.
            run_id: names the storage directory ``experiments/<run_id>``.
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.critereon = critereon
        self.scheduler = scheduler
        self.device = device

        # Saved metrics per epoch
        self.train_losses = []
        self.val_losses = []
        self.val_acc = []
        self.curr_epoch = 0  # number of completed training epochs

        # Storing, logging and notification
        self.run_id = run_id
        self.notify = notify
        self.log_file_path = os.path.join('experiments', self.run_id, 'log.txt')
        self._init_storage()

    def load(self, run_id, curr_epoch):
        """Restore trainer state from ``experiments/<run_id>/trainer-epoch-<curr_epoch>.pkl``.

        Restores the model weights, the optimizer and scheduler state, and the
        metric histories written by :meth:`save`. It then points logging and
        storage at ``run_id``.

        Optimizer and scheduler state are loaded into the objects this trainer
        was constructed with. The checkpoint must therefore come from a trainer
        that used the same optimizer/scheduler types and parameter groups.

        Args:
            run_id: experiment directory name under ``experiments/``.
            curr_epoch: epoch number in the checkpoint filename; becomes the
                trainer's current epoch.
        """
        file_path = os.path.join("experiments", run_id, "trainer-epoch-{}.pkl".format(curr_epoch))
        print("Loading from "+file_path)
        # torch.load unpickles the file, so only load checkpoints you produced yourself.
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        trainer_state_dict = torch.load(file_path, map_location=self.device)

        self.model.load_state_dict(trainer_state_dict["model_state_dict"])
        # Move the model first: Optimizer.load_state_dict casts each state tensor
        # to the device of the parameter it belongs to.
        self.model.to(self.device)
        self.optimizer.load_state_dict(trainer_state_dict["optimizer_state_dict"])
        self.scheduler.load_state_dict(trainer_state_dict["scheduler_state_dict"])

        self.train_losses = trainer_state_dict["train_losses"]
        self.val_losses = trainer_state_dict["val_losses"]
        self.val_acc = trainer_state_dict["val_acc"]
        self.curr_epoch = curr_epoch

        self.run_id = run_id
        self.log_file_path = os.path.join('experiments', self.run_id, 'log.txt')
        self._init_storage()

    def _init_storage(self):
        """Create ``experiments/<run_id>/`` if needed and announce it."""
        os.makedirs(os.path.join("experiments", self.run_id), exist_ok=True)
        print("Saving models and run statistics to ./experiments/%s" % self.run_id)

    def _emit(self, message):
        """Print ``message``, append it to the run log, and send it via ``notify``."""
        print(message)
        with open(self.log_file_path, 'a+') as f:
            f.write(message + "\n")
        self.notify.send("{} {}".format(self.run_id, message))

    def _train_notify_and_log(self, epoch_loss):
        """Report the mean training loss for the current epoch."""
        self._emit('[TRAIN]  Epoch %d Loss: %.4f' % (self.curr_epoch, epoch_loss))

    def train_epoch(self):
        """Run one optimisation pass over ``train_loader``.

        Increments ``curr_epoch`` before logging, so reported epochs are 1-indexed.
        Logs the mean per-batch loss and appends it to ``train_losses``.

        Raises:
            ValueError: if ``train_loader`` yields no batches.
        """
        self.model.train()
        self.model.to(self.device)

        epoch_loss = 0.0
        num_batches = 0
        for videos, audios, _, _, labels in tqdm(self.train_loader):  # lengths unused for single frame
            videos = videos.to(self.device)
            audios = audios.to(self.device)
            labels = labels.to(self.device)
            video_out, audio_out, binary_out, _, video_embed, audio_embed = self.model(videos, audios)
            batch_loss = self.model.loss((video_out, audio_out, binary_out),
                                         (video_embed, audio_embed, labels))

            self.optimizer.zero_grad()
            batch_loss.backward()
            self.optimizer.step()

            epoch_loss += batch_loss.item()  # model.loss presumably averages over the batch
            num_batches += 1
        if num_batches == 0:
            raise ValueError("train_loader yielded no batches")
        epoch_loss /= num_batches

        self.curr_epoch += 1
        self._train_notify_and_log(epoch_loss)
        self.train_losses.append(epoch_loss)

    def _val_batch(self, videos, audios, labels):
        """Evaluate one validation batch without tracking gradients.

        Returns:
            batch_loss: ``model.loss`` for this batch, as a Python float.
            batch_corrects: number of samples whose argmax of ``binary_out``
                over dim 1 equals the label.
            batch_size: number of samples in this batch (``labels.shape[0]``).
        """
        videos, audios, labels = videos.to(self.device), audios.to(self.device), labels.to(self.device)
        with torch.no_grad():
            video_out, audio_out, binary_out, _, video_embed, audio_embed = self.model(videos, audios)
            batch_loss = self.model.loss((video_out, audio_out, binary_out),
                                         (video_embed, audio_embed, labels)).item()

        _, max_preds = torch.max(binary_out, 1)
        batch_corrects = (max_preds == labels).sum().item()
        batch_size = labels.shape[0]
        return batch_loss, batch_corrects, batch_size

    def _val_notify_and_log(self, epoch_loss, epoch_acc):
        """Report the mean validation loss and accuracy (percent) for the current epoch."""
        self._emit('[VAL]  Epoch {} Loss: {:.4f} Acc: {:.2f}%'.format(self.curr_epoch, epoch_loss, epoch_acc))

    def val_epoch(self):
        """Evaluate on ``val_loader``, step the scheduler on the mean loss, and log.

        Appends the mean per-batch loss to ``val_losses``. Appends the accuracy
        over all samples, in percent, to ``val_acc``.

        Raises:
            ValueError: if ``val_loader`` yields no batches.
        """
        self.model.eval()
        self.model.to(self.device)

        epoch_loss = 0.0
        epoch_corrects = 0
        num_samples = 0
        num_batches = 0
        for videos, audios, _, _, labels in tqdm(self.val_loader):
            batch_loss, batch_corrects, batch_size = self._val_batch(videos, audios, labels)
            epoch_loss += batch_loss
            epoch_corrects += batch_corrects
            num_samples += batch_size
            num_batches += 1
        if num_batches == 0:
            raise ValueError("val_loader yielded no batches")

        epoch_loss /= num_batches
        epoch_acc = epoch_corrects / num_samples * 100.0
        # The scheduler receives the monitored metric (the notebook uses ReduceLROnPlateau).
        self.scheduler.step(epoch_loss)

        self._val_notify_and_log(epoch_loss, epoch_acc)
        self.val_losses.append(epoch_loss)
        self.val_acc.append(epoch_acc)

    def save(self):
        """Checkpoint the model, optimizer, scheduler and metric histories.

        Writes ``experiments/<run_id>/trainer-epoch-<curr_epoch>.pkl`` with
        ``torch.save``; :meth:`load` reads the same keys back.
        """
        trainer_state_dict = {
            "model_tostring": str(self.model),
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'val_acc': self.val_acc
        }
        torch.save(trainer_state_dict, os.path.join("experiments", self.run_id, "trainer-epoch-{}.pkl".format(self.curr_epoch)))

    def save_run_stats(self):
        """Save the metric histories to ``experiments/<run_id>/run_stats.npy``.

        The dict is stored as a pickled 0-d object array. Read it back with
        ``np.load(path, allow_pickle=True).item()``.
        """
        run_stats = {
            "train_losses" : self.train_losses,
            "val_losses" : self.val_losses,
            "val_acc" : self.val_acc
        }
        np.save(os.path.join("experiments", self.run_id, "run_stats.npy"), run_stats)

## Load Data & Create Dataloader
Load the video and audio datasets for single-frame prediction.

In [3]:
# Video datasets per split, in single-frame mode.
train_video_dataset = dataloader.get_dataset(dataloader.TRAIN_JSON_PATH, dataloader.SINGLE_FRAME)
val_video_dataset = dataloader.get_dataset(dataloader.VAL_JSON_PATH, dataloader.SINGLE_FRAME)

# Audio datasets: AudioDataset() takes no split argument, so both are constructed
# identically. NOTE(review): confirm the split comes from the paired video dataset.
train_audio_dataset = dataloader.AudioDataset()
val_audio_dataset = dataloader.AudioDataset()

loaded 733589 images 
loaded 137015 images 


In [4]:
BATCH_SIZE = 64

train_loader = dataloader.AVDataLoader(train_video_dataset, train_audio_dataset, batch_size=BATCH_SIZE, shuffle=True, single_frame=True)
val_loader = dataloader.AVDataLoader(val_video_dataset, val_audio_dataset, batch_size=BATCH_SIZE, shuffle=False, single_frame=True)

# Sanity-check one training batch. Dim 0 is the batch dimension. The first batch
# was observed with 53 samples (< BATCH_SIZE).
# NOTE(review): confirm how AVDataLoader forms batches.
print("Verifying data sizes")
sample_videos, sample_audios, _, _, sample_labels = next(iter(train_loader))
print('videos shape:', sample_videos.shape)  # observed (N, 3, 224, 224)
print('audios shape:', sample_audios.shape)  # observed (N, 5, 50)
print('labels shape:', sample_labels.shape)  # observed (N,)
assert sample_videos.shape[0] == sample_audios.shape[0] == sample_labels.shape[0], \
    "video/audio/label batch sizes differ"

Verifying data sizes
videos shape: torch.Size([53, 3, 224, 224])
audios shape: torch.Size([53, 5, 50])
labels shape: torch.Size([53])


## Data Exploration

In [5]:
# train_labels = []
# for _, _, _, _, l in tqdm(train_loader):
#     train_labels.append(l)

# print("train_labels", torch.cat(train_labels))
# print(torch.cat(train_labels).shape)

# val_labels = []
# for _, _, _, _, l in tqdm(val_loader):
#     val_labels.append(l)

# print("val_labels", torch.cat(val_labels))
# print(torch.cat(val_labels).shape)

## Define Model and Training Parameters

In [6]:
run_id = "MultiModalAESimpleDense"  # use "test" for throwaway runs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The notify.run endpoint URL acts as the channel's credential, so read it from
# the environment rather than committing it. Set NOTIFY_RUN_ENDPOINT before
# running. When unset, Notify presumably falls back to the endpoint saved by
# `notify-run configure` — confirm.
notify = Notify(endpoint=os.environ.get("NOTIFY_RUN_ENDPOINT"))

LEARNING_RATE = 0.01
model = mm_ae_simple_models.MultiModalAESimpleDensenetModel()
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)
critereon = torch.nn.CrossEntropyLoss()  # passed to the trainer, which computes losses via model.loss instead
# Cut the LR 10x after 2 epochs without val-loss improvement, down to 1e-6.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, threshold=0.0001, verbose=True,
                                                      patience=2, min_lr=1e-6)

trainer = ModelTrainer(model, train_loader, val_loader, optimizer, critereon, scheduler, device, notify, run_id)

Saving models and run statistics to ./experiments/MultiModalAESimpleDense


In [7]:
# Initialise the DenseNet branch from standalone weights, then freeze it.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
# NOTE: trainer.load(...) restores the full model_state_dict, which overwrites
# these weights when resuming. The freezing below still applies.
densenet_state = torch.load("densenet169.pth", map_location=device)["model_state_dict"]
trainer.model.densenet.load_state_dict(densenet_state)
for param in trainer.model.densenet.parameters():
    param.requires_grad = False

In [8]:
# Resume this run from the checkpoint written at the end of epoch 1.
run_id, curr_epoch = "MultiModalAESimpleDense", 1
trainer.load(run_id, curr_epoch)

Loading from experiments/MultiModalAESimpleDense/trainer-epoch-1.pkl
Saving models and run statistics to ./experiments/MultiModalAESimpleDense


In [9]:
# # self.bad.append((videos, audios, labels, preds, batch_loss))
# # self.bad = (videos.cpu(), audios.cpu(), labels.cpu(), preds.cpu(), batch_loss)

# # print(trainer.bad)
# model.to("cuda")
# print(model(trainer.bad[0][:2].cuda(),trainer.bad[1][:2].cuda()))
# print(trainer.bad[3], trainer.bad[2])
# print(trainer.critereon(trainer.bad[3], trainer.bad[2]))

## Train

In [None]:
NUM_EPOCHS = 20  # epochs to run in this session, on top of any resumed checkpoint
try:
    for _ in tqdm(range(NUM_EPOCHS)):
        trainer.train_epoch()
        trainer.val_epoch()
        trainer.save()
finally:
    # Persist the metric histories even if the loop is interrupted (e.g. a kernel interrupt).
    trainer.save_run_stats()

HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, max=11463), HTML(value='')))


[TRAIN]  Epoch 2 Loss: 0.1340


HBox(children=(IntProgress(value=0, max=2141), HTML(value='')))


[VAL]  Epoch 2 Loss: 0.0985 Acc: 95.60%


HBox(children=(IntProgress(value=0, max=11463), HTML(value='')))


[TRAIN]  Epoch 3 Loss: 0.0160


HBox(children=(IntProgress(value=0, max=2141), HTML(value='')))


[VAL]  Epoch 3 Loss: 0.1054 Acc: 95.37%


HBox(children=(IntProgress(value=0, max=11463), HTML(value='')))


[TRAIN]  Epoch 4 Loss: 0.0141


HBox(children=(IntProgress(value=0, max=2141), HTML(value='')))


[VAL]  Epoch 4 Loss: 0.1047 Acc: 95.42%


HBox(children=(IntProgress(value=0, max=11463), HTML(value='')))