In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
import pickle
from glob import glob
import os
import yaml
from easydict import EasyDict as edict

In [None]:
import sys

# Put the parent directory on the import path. This must run before the
# `from model.controlmodel import *` cell; presumably the `model` package lives
# one level above this notebook -- TODO confirm against the repository layout.
sys.path.append('../')

In [None]:
from model.controlmodel import *
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, random_split, ConcatDataset

In [None]:
import pytorch_lightning as pl
from torchmetrics import Accuracy

In [None]:
# Resolve the experiment configuration file; indexing [0] raises IndexError if
# the path does not exist relative to the notebook's working directory.
config_file = glob('../config/control_model/control_mlp_cifar.yaml')[0]

# Read the YAML inside a context manager so the file handle is closed as soon
# as parsing finishes (the bare `open(...)` call previously leaked it), then
# wrap the parsed mapping in an EasyDict for attribute-style access
# (e.g. `config.model.hidden_dim`).
with open(config_file, 'r') as config_stream:
    config = edict(yaml.load(config_stream, Loader=yaml.FullLoader))

In [None]:
# Preprocessing applied to every image: convert to a tensor, then normalize
# each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# CIFAR-10 training split. `download=False` means the data must already be
# present under the given root directory.
train_dataset = datasets.CIFAR10(
    root='../xor_neuron_data/data',
    train=True,
    transform=transform,
    download=False,
)

# CIFAR-10 held-out split (train=False), used here for validation.
validation_dataset = datasets.CIFAR10(
    root='../xor_neuron_data/data',
    train=False,
    transform=transform,
    download=False,
)

In [None]:
# Training batches are reshuffled every epoch so SGD does not see the samples
# in the same fixed order (and the same batch composition) on each pass.
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=16,
                          shuffle=True)

# Validation keeps a fixed order; shuffling has no effect on the metrics and a
# deterministic order makes runs easier to compare.
validation_loader = DataLoader(dataset=validation_dataset,
                               batch_size=16,
                               shuffle=False)


In [None]:
class MLP_model(nn.Module):
    """Plain multi-layer perceptron classifier.

    Each hidden stage is ``Linear -> LayerNorm (no affine) -> ReLU -> Dropout``,
    followed by a final ``Linear`` producing ``num_classes`` outputs.

    Args:
        config: object exposing ``config.model.input_dim`` (flattened input
            size), ``config.model.hidden_dim`` (sequence of hidden widths; may
            be empty, in which case the network is a single linear layer),
            ``config.model.num_classes``, ``config.model.dropout`` (dropout
            probability) and ``config.model.loss`` (stored on the instance but
            not used by this class).
    """

    def __init__(self, config):
        super(MLP_model, self).__init__()
        self.control = config
        self.input_dim = config.model.input_dim
        self.num_classes = config.model.num_classes
        self.dropout = config.model.dropout
        self.hidden_dim = config.model.hidden_dim

        self.loss = config.model.loss

        # Widths of consecutive layers: input first, then every hidden width.
        # Pairing neighbours yields the (in, out) sizes of each hidden Linear,
        # which removes the need to special-case the first layer.
        layer_dims = [self.input_dim, *self.hidden_dim]
        self.model = nn.ModuleList([
            nn.Linear(fan_in, fan_out)
            for fan_in, fan_out in zip(layer_dims[:-1], layer_dims[1:])
        ])

        # The output layer reads from the last hidden width, or directly from
        # the input when no hidden layers are configured.
        self.fc_out = nn.Linear(layer_dims[-1], self.num_classes)

        self.drop_layer = nn.Dropout(p=self.dropout)
        self.activation_fnc = nn.ReLU()

    def forward(self, x):
        """Compute the output of the final linear layer for a batch.

        Args:
            x: tensor whose first dimension is the batch; all remaining
                dimensions are flattened into one feature axis.

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        # Collapse everything after the batch dimension into a single axis.
        out = x.flatten(start_dim=1)

        for fc in self.model:
            out = fc(out)
            # Parameter-free layer normalization over the feature axis. The
            # functional form avoids constructing a new nn.LayerNorm module on
            # every forward pass.
            out = F.layer_norm(out, tuple(out.shape[1:]))
            out = self.activation_fnc(out)
            out = self.drop_layer(out)

        return self.fc_out(out)

In [None]:
class Classifier(pl.LightningModule):
    """Lightning wrapper that trains and evaluates a ``Control_MLP``.

    The wrapped model is called as ``model(x, y)`` and is expected to return a
    ``(logits, loss)`` pair, i.e. it computes its own loss from the targets.
    ``Control_MLP`` is not defined in this notebook; presumably it comes from
    the ``model.controlmodel`` star import -- verify there.

    Args:
        config: configuration object forwarded unchanged to ``Control_MLP``.
    """

    def __init__(self, config):
        super().__init__()
        self.model = Control_MLP(config)
        # NOTE(review): `Accuracy()` without arguments matches older
        # torchmetrics releases; newer ones require a `task` argument --
        # confirm against the installed version.
        self.acc_fnc = Accuracy()

    def forward(self, x, y):
        """Run the wrapped model on inputs ``x`` with targets ``y``.

        Returns:
            Whatever ``Control_MLP`` returns; the step methods unpack it as
            ``(logits, loss)``.
        """
        return self.model(x, y)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch, then return it."""
        x, y = batch
        _, loss = self(x, y)
        self.log_dict({'train_loss': loss})

        return loss

    def validation_step(self, batch, batch_idx):
        """Log accuracy and loss for one validation batch."""
        x, y = batch

        logits, loss = self(x, y)
        acc = self.acc_fnc(logits, y)

        self.log_dict({'val_acc': acc, 'val_loss': loss})

    def test_step(self, batch, batch_idx):
        """Log accuracy and loss for one test batch."""
        x, y = batch

        # forward() requires the targets as well; calling self(x) alone raised
        # a TypeError for the missing `y` argument.
        logits, loss = self(x, y)
        acc = self.acc_fnc(logits, y)

        self.log_dict({'test_acc': acc, 'test_loss': loss})

    def configure_optimizers(self):
        """Build the optimizer and learning-rate scheduler.

        Returns:
            ``([optimizer], [scheduler])``: SGD with Nesterov momentum and
            weight decay, paired with a MultiStepLR scheduler that multiplies
            the learning rate by 0.001 at milestone 1000.
        """
        optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=0.001,
            weight_decay=0.00001,
            momentum=0.9,
            nesterov=True,
        )

        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[1000],
            gamma=0.001)

        return [optimizer], [scheduler]

In [None]:
# Instantiate the Lightning classifier from the loaded configuration.
control_model = Classifier(config)

In [None]:
# Bare expression: the notebook shows the module's repr as the cell output.
control_model

In [None]:
# Short run of three epochs; gpus=0 requests no GPU devices.
# NOTE(review): the `gpus` argument belongs to older Lightning releases --
# confirm it is still accepted by the installed version.
trainer = pl.Trainer(
    max_epochs=3,
    gpus=0,
)

# Fit the classifier, passing the training loader first and the validation
# loader second.
trainer.fit(
    control_model,
    train_loader,
    validation_loader,
)