In [18]:
import os
import json
from PIL import Image

import numpy as np
import pandas as pd

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torchvision import datasets, transforms, models
import torch.nn as nn
from torchvision.models import resnet50
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR


In [19]:
# Check if a GPU is available
if torch.cuda.is_available():
    # Request GPU device 0
    device = torch.device("cuda:0")
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")
else:
    # If no GPU is available, fall back to CPU
    device = torch.device("cpu")
    print("No GPU available, using CPU.")

Using GPU: NVIDIA RTX A5000


In [20]:
class TaskDataset(Dataset):
    # todo: this is not the most efficient way to access data, since each time it has to read from the directory 
    def __init__(self, root_dir,):
        self.root_dir = root_dir
        # preprocessing steps for pretrained ResNet models
        self.transform = transforms.Compose([
                            transforms.Resize(224),
                            transforms.CenterCrop(224), # todo: to delete for shapenet task; why?
                            transforms.ToTensor(),
                            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                        ])

        # check the size of the dataset
        self.dataset_size = 0
        items = os.listdir(self.root_dir)
        for item in items:
            item_path = os.path.join(self.root_dir, item)
            # Check if the item is a directory
            if os.path.isdir(item_path):
                self.dataset_size += 1

    def __len__(self):
        return self.dataset_size

    def __getitem__(self, idx):
        trial_path = os.path.join(self.root_dir, "trial%d"%idx)
        # todo: only consider 2 frames in this case => need to make it generalizable
        images = []
        for i in range(2):
            image = Image.open(os.path.join(trial_path, "epoch%d.png"%i))
            image = self.transform(image)
            
            images.append(image)
        images = np.stack(images) # (2*3*224*224)
        with open(os.path.join(trial_path, "trial_info"), 'r') as json_file:
            data = json.load(json_file)
            # 'data' now contains the contents of the JSON file as a Python dictionary
            actions = self._action_map(data["answers"])[-1]

        return images, torch.tensor(actions)
    
    def _action_map(self, actions):
        updated_actions = []
        for action in actions:
            if action == "null":
                updated_actions.append(2)
            elif action == "false":
                updated_actions.append(0)
            elif action == "true":
                updated_actions.append(1)
        return updated_actions


In [30]:
# define the network
IMGM_PATH = 'tutorials/offline_models/resnet/resnet'


activation = {}
def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook


class CNNRNNNet(nn.Module):

    def __init__(self, hidden_size, output_size = 3,):
        super().__init__()

        # set up the CNN model
        self.cnnmodel = torch.load(IMGM_PATH, map_location=device)
        # freeze layers of cnn model
        for paras in self.cnnmodel.parameters():
            paras.requires_grad = False
        # get relu activation of last block of resnet50
        
        self.cnnmodel.layer4[2].relu.register_forward_hook(get_activation('relu'))

        self.cnnlayer = torch.nn.Conv2d(2048, hidden_size, 1) # we can also bring the resnet embedding dim to a number different from hidden size

        self.input_size = hidden_size*7*7
        self.hidden_size = hidden_size
        self.output_size = output_size

        self.in2hidden = nn.Linear(self.input_size, hidden_size)
        self.layer_norm_in = nn.LayerNorm(self.hidden_size)
        
        self.rnn = nn.RNN(
            input_size = self.hidden_size, 
            hidden_size = self.hidden_size,
            nonlinearity = "relu", # guarnatee positive activations
            )

        self.layer_norm_rnn = nn.LayerNorm(self.hidden_size)
        self.hidden2output = nn.Linear(self.hidden_size, self.output_size)


    def forward(self, input_img, hidden_state = None, is_noise = False,):
        # preprocess image with resnet
        self.batch_size = input_img.shape[0]
        self.seq_len = input_img.shape[1]
        
        x = torch.swapaxes(input_img, 0, 1).float()# (seq_len, batchsize, nc, w, h)
        
        x_acts = []
        cnn_acts = []
        for i in range(self.seq_len):
            temp = self.cnnmodel(x[i,:,:,:,:])
            cnn_acts.append(activation["relu"])
            x_act = self.cnnlayer(activation["relu"])
            x_acts.append(x_act) # (batchsize, nc, w, h) = (batchsize, 2048, 7,7)
        
        x_acts = torch.stack(x_acts, axis = 0) # (seqlen, batchsize,nc, w,h)
        self.cnn_acts = torch.stack(cnn_acts, axis = 0) # (seqlen, batchsize, nc, w,h)
        self.cnn_acts_down = x_acts
        
        x_acts = x_acts.reshape(x_acts.shape[0], self.batch_size, -1)
        
        if hidden_state == None:
            self.hidden_state = self.init_hidden(batch_size = self.batch_size)

        hidden_x = self.layer_norm_in(torch.relu(self.in2hidden(x_acts.float()))).to(device)
        
        rnn_output, _ = self.rnn(hidden_x, self.hidden_state.to(device))
        
        rnn_output = self.layer_norm_rnn(rnn_output)
        
        out = self.hidden2output(torch.tanh(rnn_output))
        
        
        return out[-1, :, :]
        


    def init_hidden(self, batch_size):
        return nn.init.kaiming_uniform_(torch.empty(1, batch_size, self.hidden_size))


In [31]:
train_TD = TaskDataset("datasets/train_big")
val_TD = TaskDataset("datasets/val_big")
                 
batch_size = 64
data_loaders = [DataLoader(train_TD, batch_size=batch_size, shuffle=True),
                DataLoader(val_TD, batch_size=batch_size, shuffle=False)]
                
model = CNNRNNNet(hidden_size = 256, output_size = 3,).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr = 1e-3)

In [32]:
# Calculates the number of correct null action predictions and the number of correct non-null action predictions
def correct(preds, targs):
    null_idxs = torch.where(targs.cpu() == 2)
    non_null_idxs = torch.where(targs.cpu() < 2)
    
    null_preds = preds[null_idxs]
    non_null_preds = preds[non_null_idxs]
    
    c_null = torch.sum(null_preds == targs[null_idxs])
    n_null = len(null_preds)
    null_acc = c_null/n_null
    
    c_non_null = torch.sum(non_null_preds == targs[non_null_idxs])
    n_non_null = len(non_null_preds)
    non_null_acc = c_non_null/n_non_null
    
    return null_acc, non_null_acc

In [38]:
for epoch in range(10):
    for i, data_loader in enumerate(data_loaders):
        if i == 0: mode = "train"
        else: mode = "val"
        
        null_accs = []
        non_null_accs = []
        train_losses = []
        for images, actions in data_loader:
            if mode == "train":
                model.train()
                optimizer.zero_grad()
            else:
                model.eval()
            output = model(images.to(device))
            print(actions.shape)
            train_loss = criterion(output, actions.type(torch.LongTensor).reshape(-1).to(device))
            train_losses.append(train_loss.item())
#             print(train_loss)
            if mode == "train":
                train_loss.backward()
                optimizer.step()
            _, predicted = torch.max(output.data, 2)
            predicted = predicted.permute(1,0).reshape(-1)
            acc = torch.sum(predicted == actions.reshape(-1).to(device))/(len(predicted))
            accs.append(acc)
        if mode == "train":
            print("epoch %d, current loss %.2f, TRAIN current acc %.2f" % (epoch, sum(train_losses)/len(train_losses), sum(accs)/len(accs)))
        elif mode == "val":
            print(predicted)
            print(actions.reshape(-1))
            print("epoch %d, current loss %.2f, VAL current acc %.2f" % (epoch, sum(train_losses)/len(train_losses), sum(accs)/len(accs)))

torch.Size([64, 2])


ValueError: Expected input batch_size (64) to match target batch_size (128).