In [230]:
import torch
from torch.utils.data import Dataset
import pickle
from torch import nn
import torch.nn.functional as F
import numpy as np

In [231]:
with open('mnist_mad_tensor_activations.pkl', 'rb') as f:
    dataset = pickle.load(f)

In [232]:
dataset[0][0].shape, dataset[0][1].shape, dataset[0][2].shape, dataset[0][3]

(torch.Size([1, 256]), torch.Size([1, 128]), torch.Size([1, 10]), tensor(1.))

In [233]:
from torch.utils.data import Dataset

class ConcatDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset
        
    def __getitem__(self, index):
        input1, input2, input3, output = self.dataset[index]
        input_concat = torch.cat((input1, input2, input3), dim=1)
        return input_concat, output
    
    def __len__(self):
        return len(self.dataset)

In [234]:
concat_dataset = ConcatDataset(dataset)
concat_dataset[0][0].shape, concat_dataset[0][1]

(torch.Size([1, 394]), tensor(1.))

In [235]:
import random
from torch.utils.data import Subset

# define the size of the training set
train_size = int(0.8 * len(concat_dataset))

# create a list of indices for the training set and the evaluation set
indices = list(range(len(concat_dataset)))
random.shuffle(indices)
train_indices = indices[:train_size]
eval_indices = indices[train_size:]

# create a PyTorch Subset for the training set and the evaluation set
train_dataset = Subset(concat_dataset, train_indices)
eval_dataset = Subset(concat_dataset, eval_indices)

In [236]:
class ConcatClassifier(nn.Module):

    def __init__(self, concat_dim=394):
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(concat_dim, 100),
            nn.ReLU(),
            nn.Linear(100, 20),
            nn.ReLU(),
            nn.Linear(20, 2),
        )
    
    def forward(self, x):
        return self.main(x)

concat_classifier = ConcatClassifier()

In [237]:
def evaluate(loader, model):
    with torch.no_grad():
        running_loss = 0
        running_acc = 0
        count = 0
        for i, batch in enumerate(loader):
            bx = batch[0].cuda()
            by = batch[1].cuda()

            count += by.size(0)
            logits = model(bx)
            pred = logits.argmax(dim=-1).squeeze()
            loss = F.binary_cross_entropy(pred, by)
            running_loss += loss.cpu().numpy()
            running_acc += (torch.max(logits, dim=1)[1].float() == by).float().sum(0).cpu().numpy()
        loss = running_loss / count
        acc = running_acc / count
    return loss, acc

In [238]:
def train_model(train_data, test_data, model, num_epochs=10, batch_size=64):
    """
    :param train_data: the data to train with
    :param test_data: the clean test data to evaluate accuracy on
    :param model: the model to train
    :param num_epochs: the number of epochs to train for
    :param batch_size: the batch size for training
    """
    train_loader = torch.utils.data.DataLoader(train_data, batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size, shuffle=True)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(train_loader)*num_epochs)

    loss_ema = np.inf

    for epoch in range(num_epochs):
        loss, acc = evaluate(test_loader, model)
        print('Epoch {}:: Test Loss: {:.3f}, Test Acc: {:.3f}'.format(epoch, loss, acc))
        for i, (bx, by) in enumerate(train_loader):

            bx = bx.cuda()
            by = by.cuda()

            logits = model(bx)
            pred = logits.argmax(dim=-1).squeeze()
            loss = F.binary_cross_entropy(pred, by)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            if loss_ema == np.inf:
                loss_ema = loss.item()
            else:
                loss_ema = loss_ema * 0.95 + loss.item() * 0.05

            if i % 500 == 0:
                print('Train loss: {:.3f}'.format(loss_ema))  # to get a rough idea of training loss

    loss, acc = evaluate(test_loader, model)
    
    print('Final Metrics:: Test Loss: {:.3f}, Test Acc: {:.3f}'.format(
        loss, acc))
    
    return loss, acc

In [239]:
model = concat_classifier.cuda()
loss, acc = train_model(train_dataset, eval_dataset, model,
                                          num_epochs=10, batch_size=256)
loss, acc

RuntimeError: ignored