# Trivial Model Approach

Every spacetime frame is passed through a pretrained image classifier (ResNet18) and its embedding is extracted. Each frame embedding is concatenated with the embedding of the first frame, and the resulting vectors are stacked along the temporal axis and fed into an LSTM. At every timestep, the LSTM output is passed through a linear layer to produce a 42-dimensional vector; these vectors are again stacked along the temporal axis to form a cochleagram. A reconstruction loss is computed between the predicted cochleagram and the true cochleagram, and this loss is backpropagated to train the network.

In [1]:
import sys

sys.path.append('..')

import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl

from torch.utils.data import DataLoader
from torchvision.models import resnet18, ResNet18_Weights

from VISTorchUtils import VISDataset, VISLoss
from VISDataPoint import VISDataPoint
from utils import visCollate

In [2]:
# Build the two dataset splits; the on-disk "test" folder serves as the
# validation set for this notebook.
trainDataset, valDataset = (
    VISDataset(splitDir)
    for splitDir in ('/scratch/vis_data/train', '/scratch/vis_data/test')
)

In [3]:
# Both loaders use the same batch size, collate function and worker count.
# Only the training loader shuffles, so that batch composition changes from
# epoch to epoch; validation order is left deterministic.
# NOTE(review): shuffle=True assumes VISDataset is a map-style dataset
# (implements __len__/__getitem__) — confirm against VISTorchUtils.
trainDataLoader = DataLoader(trainDataset, batch_size=2, shuffle=True, collate_fn=visCollate, num_workers=4)
valDataLoader = DataLoader(valDataset, batch_size=2, collate_fn=visCollate, num_workers=4)

In [4]:
class VISTrivialModel(pl.LightningModule):
    """Baseline model: per-frame ResNet18 embeddings -> LSTM -> linear head.

    Every spacetime frame and the first video frame are embedded with the
    same ResNet18 backbone (its classification layer replaced by identity).
    Each spacetime-frame embedding is concatenated with the first-frame
    embedding, the sequence is run through an LSTM, and a linear layer maps
    every timestep to ``outputSize`` values.

    Args:
        outputSize: Number of values predicted per timestep.
        hiddenSize: Hidden size of the LSTM. Defaults to 1024, the value
            that was previously hard-coded.
    """

    def __init__(self, outputSize: int, hiddenSize: int = 1024):
        super().__init__()
        # `weights` must be passed by keyword; the positional form is
        # deprecated in torchvision >= 0.13.
        self.featureExtractor = resnet18(weights=ResNet18_Weights.DEFAULT)
        # Read the embedding width from the backbone before dropping its
        # classifier, instead of hard-coding it.
        featureSize = self.featureExtractor.fc.in_features
        self.featureExtractor.fc = nn.Identity()

        # The LSTM input is [spacetime-frame embedding ; first-frame embedding].
        self.lstm = nn.LSTM(2 * featureSize, hiddenSize, batch_first=True)
        self.fc = nn.Linear(hiddenSize, outputSize)

    def forward(self, stFrames, frame0):
        """Predict one ``outputSize`` vector per spacetime frame.

        Args:
            stFrames: Batch of spacetime-frame sequences, indexed as
                (batch, time, <image dims>).
            frame0: Batch of first frames, (batch, <image dims>).

        NOTE(review): the original comments described the images as
        224x224x3 (channels-last), but the backbone is a torchvision ResNet,
        which consumes channels-first input — presumably the collate
        function already yields (3, H, W) images; confirm against visCollate.

        Returns:
            Tensor of shape (batch, outputSize, time).
        """
        numFrames = stFrames.shape[1]

        # Embed each timestep separately. Integer indexing already removes
        # the time axis, so no extra squeeze is needed.
        stFrameFeatures = torch.stack(
            [self.featureExtractor(stFrames[:, t]) for t in range(numFrames)],
            dim=1,
        )

        # Broadcast the first-frame embedding across all timesteps.
        frame0Features = (
            self.featureExtractor(frame0).unsqueeze(1).expand(-1, numFrames, -1)
        )

        # LSTM input: (batch, time, 2 * featureSize).
        X = torch.cat([stFrameFeatures, frame0Features], dim=2)
        X, _ = self.lstm(X)

        # nn.Linear acts on the last dimension, so it maps every timestep in
        # one call; then move time to the last axis.
        return self.fc(X).transpose(1, 2)

    def _sharedStep(self, batch, stage):
        """Compute and log the loss for one batch; ``stage`` prefixes the log key."""
        coch, stFrames, frame0, _ = batch
        out = self(stFrames, frame0)
        loss = VISLoss()(out, coch)
        self.log(f'{stage}_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for ``batch`` (logged as ``train_loss``)."""
        return self._sharedStep(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for ``batch`` (logged as ``val_loss``)."""
        return self._sharedStep(batch, 'val')

    def configure_optimizers(self):
        """Optimise all parameters with Adam at a learning rate of 1e-3."""
        return optim.Adam(self.parameters(), lr=1e-3)

In [5]:
# The model emits one 42-dimensional vector per timestep.
outputVectorSize = 42
model = VISTrivialModel(outputSize=outputVectorSize)



In [6]:
# Trainer settings: one GPU device, at most 10 epochs.
trainerOptions = dict(accelerator='gpu', devices=1, max_epochs=10)
trainer = pl.Trainer(**trainerOptions)

  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [7]:
# Fit the model, supplying the training loader and the validation loader.
trainer.fit(
    model,
    train_dataloaders=trainDataLoader,
    val_dataloaders=valDataLoader,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type   | Params
--------------------------------------------
0 | featureExtractor | ResNet | 11.2 M
1 | lstm             | LSTM   | 8.4 M 
2 | fc               | Linear | 43.1 K
--------------------------------------------
19.6 M    Trainable params
0         Non-trainable params
19.6 M    Total params
78.465    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
