In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
import pickle
from glob import glob
import os
import yaml
from easydict import EasyDict as edict

In [3]:
import sys

# Make packages one directory up (e.g. `model`) importable; assumes the
# kernel's working directory is this notebook's folder — TODO confirm.
sys.path.extend(['../'])

In [4]:
from model.controlmodel import *
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, random_split, ConcatDataset

In [5]:
import pytorch_lightning as pl
from torchmetrics import Accuracy

In [6]:
# Resolve the experiment config; fail with a clear message instead of an
# opaque IndexError from glob(...)[0] when the file is missing.
CONFIG_PATTERN = '../config/control_model/control_mlp_cifar.yaml'
config_matches = glob(CONFIG_PATTERN)
if not config_matches:
    raise FileNotFoundError(f'No config file matches {CONFIG_PATTERN!r}')
config_file = config_matches[0]

# `with` guarantees the file handle is closed once parsing finishes.
# FullLoader is kept for compatibility with existing configs (trusted local file).
with open(config_file, 'r') as config_stream:
    config = edict(yaml.load(config_stream, Loader=yaml.FullLoader))

In [7]:
CIFAR10_ROOT = '../xor_neuron_data/data'

# ToTensor yields values in [0, 1]; Normalize(mean=0.5, std=0.5) per channel
# then maps each RGB channel to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])


def _cifar10_split(train):
    """Load one CIFAR-10 split from CIFAR10_ROOT with the shared transform.

    download=False: the data must already be present under CIFAR10_ROOT.
    """
    return datasets.CIFAR10(root=CIFAR10_ROOT,
                            train=train,
                            transform=transform,
                            download=False)


train_dataset = _cifar10_split(train=True)
validation_dataset = _cifar10_split(train=False)

In [8]:
BATCH_SIZE = 16

# Shuffle the training split so each epoch visits batches in a new order;
# evaluation order does not affect the metrics, so validation stays fixed.
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=True)

validation_loader = DataLoader(dataset=validation_dataset,
                               batch_size=BATCH_SIZE,
                               shuffle=False)


In [10]:
class Classifier(pl.LightningModule):
    """LightningModule wrapping ``Control_MLP`` for classification.

    The wrapped model is called as ``model(x, y)`` and its result is unpacked
    as a ``(logits, loss)`` pair. This module adds metric logging and the
    optimizer/scheduler setup on top of it.

    Args:
        config: experiment configuration, passed unchanged to ``Control_MLP``.
    """

    def __init__(self, config):
        super().__init__()
        self.model = Control_MLP(config)
        # One Accuracy instance serves both validation and test; only its
        # per-batch return value is logged.
        self.acc_fnc = Accuracy()

    def forward(self, x, y):
        """Run the wrapped model on inputs ``x`` with labels ``y``.

        Returns whatever ``Control_MLP`` returns; the step methods below
        unpack it as ``(logits, loss)``.
        """
        return self.model(x, y)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one ``(x, y)`` batch."""
        x, y = batch
        _, loss = self(x, y)
        self.log_dict({'train_loss': loss})
        return loss

    def _eval_step(self, batch, prefix):
        """Shared validation/test logic: log ``<prefix>_acc`` and ``<prefix>_loss``."""
        x, y = batch
        # Labels must be passed: forward() takes (x, y) because the wrapped
        # model computes the loss itself.
        logits, loss = self(x, y)
        acc = self.acc_fnc(logits, y)
        self.log_dict({f'{prefix}_acc': acc, f'{prefix}_loss': loss})

    def validation_step(self, batch, batch_idx):
        """Log validation accuracy and loss for one batch."""
        self._eval_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Log test accuracy and loss for one batch."""
        self._eval_step(batch, 'test')

    def configure_optimizers(self):
        """Return SGD with Nesterov momentum plus a MultiStepLR schedule."""
        optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=0.001,
            weight_decay=0.00001,
            momentum=0.9,
            nesterov=True,
        )

        # NOTE(review): Lightning steps this scheduler once per epoch by
        # default, so milestones=[1000] only triggers after 1000 epochs and
        # never fires in short runs; when it does, gamma=0.001 cuts the LR
        # by 1000x. Confirm these values are intended.
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[1000],
            gamma=0.001)

        return [optimizer], [scheduler]

In [11]:
# Seed the Python, NumPy and torch RNGs before building the model so weight
# initialisation (and later torch-RNG draws such as dropout and loader
# shuffling) is reproducible under Restart & Run All.
SEED = 0
pl.seed_everything(SEED)
control_model = Classifier(config)

In [12]:
# Display the module tree (wrapped Control_MLP layers and the accuracy metric).
control_model

Classifier(
  (model): Control_MLP(
    (model): ModuleList(
      (0): Linear(in_features=3072, out_features=124, bias=True)
      (1): Linear(in_features=124, out_features=124, bias=True)
      (2): Linear(in_features=124, out_features=124, bias=True)
    )
    (fc_out): Linear(in_features=124, out_features=10, bias=True)
    (drop_layer): Dropout(p=0.5, inplace=False)
    (loss_func): CrossEntropyLoss()
    (activation_fnc): ReLU()
  )
  (acc_fnc): Accuracy()
)

In [13]:
MAX_EPOCHS = 3

# CPU-only run (gpus=0): train on train_loader, validating on validation_loader.
trainer = pl.Trainer(gpus=0, max_epochs=MAX_EPOCHS)
trainer.fit(control_model, train_loader, validation_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name    | Type        | Params
----------------------------------------
0 | model   | Control_MLP | 413 K 
1 | acc_fnc | Accuracy    | 0     
----------------------------------------
413 K     Trainable params
0         Non-trainable params
413 K     Total params
1.653     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]