<a href="https://colab.research.google.com/github/HyberionBrew/GTN/blob/main/GTN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
! pip install pytorch-lightning
! pip install pytorch-lightning-bolts
! pip install wandb
! pip install optuna

In [2]:
from argparse import ArgumentParser
import sys
import copy
import time
import json
import platform

import psutil
import tqdm
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd

import torch
import pytorch_lightning as pl
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split

from torchvision.datasets.mnist import MNIST
from torchvision import transforms

import optuna
import wandb

import pprint
import torchmetrics
import torch.nn as nn
from pytorch_lightning.loggers import WandbLogger


In [4]:
!nvidia-smi

Fri May 20 12:35:10 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   34C    P0    27W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [21]:
from pl_bolts.datamodules import MNISTDataModule


In [59]:

class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([ transforms.Grayscale(3),
                                             transforms.ToTensor(), 
                                             transforms.Normalize((0.1307,), (0.3081,)),
                                             #transforms.Resize(64),
                                            ])

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage= None):

        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

        if stage == "predict" or stage is None:
            self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=32)

In [77]:
class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([transforms.ToTensor(), 
                                             transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
                                             #transforms.Resize(64),
                                            ])

    def prepare_data(self):
        # download
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage= None):

        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            cifar10_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
            self.cifar10_train, self.cifar10_val = random_split(cifar10_full, [len(cifar10_full)- int(len(cifar10_full)*0.2),  int(len(cifar10_full)*0.2)])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.cifar10_test = CIFAR10(self.data_dir, train=False, transform=self.transform)

        if stage == "predict" or stage is None:
            self.cifar10_predict = CIFAR10(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.cifar10_train, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.cifar10_val, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.cifar10_test, batch_size=32)

    def predict_dataloader(self):
        return DataLoader(self.cifar10_predict, batch_size=32)

In [78]:
mnist = MNISTDataModule()
mnist.prepare_data()
mnist.setup()
train_loader = mnist.train_dataloader()
val_loader = mnist.val_dataloader()
test_loader = mnist.test_dataloader()

In [79]:
from torchvision.datasets.cifar import CIFAR10

In [80]:
cifar10 = CIFAR10DataModule()
cifar10.prepare_data()
cifar10.setup()
train_loader = cifar10.train_dataloader()
val_loader = cifar10.val_dataloader()
test_loader = cifar10.test_dataloader()

Files already downloaded and verified
Files already downloaded and verified


In [81]:
wandb.login()

True

In [82]:
"""
class model(nn.Module):
    def __init__(self):
        
    def forward(self, x):
        return x
"""
import torchvision.models as models
model_frozen = models.resnet18(pretrained=True)
for param in model_frozen.parameters():
    param.requires_grad = False

model_frozen.fc= nn.LazyLinear(10, bias=True)
#model_frozen.fc.weight.requires_grad = True
model_full_train = models.resnet18(pretrained=True)
model_full_train.fc= nn.LazyLinear(10, bias=True)



In [83]:
class LigModel(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.total = 0
        self.correct = 0
        self.model = model
         # log hyperparameters
        #self.save_hyperparameters()

        # compute the accuracy -- no need to roll your own!
        self.train_acc = torchmetrics.Accuracy()
        self.valid_acc = torchmetrics.Accuracy()
        self.test_acc = torchmetrics.Accuracy()
            
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        # logging metrics we calculated by hand
        self.log('train/loss', loss,on_step=True, on_epoch=True, prog_bar=True, logger=True)
        # logging a pl.Metric
        self.train_acc(y_hat, y)
        self.log('train/acc', self.train_acc, on_epoch=True)
        return loss

    def train_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        _, predicted = torch.max(y_hat.data, 1)
        # logging metrics we calculated by hand
        self.log('train/loss', loss, on_epoch=False, logger=True)
        # logging a pl.Metric
        self.valid_acc(y_hat, y)
        self.log('train/acc', self.valid_acc,prog_bar=True, on_epoch=False)
        return y.size(0),(predicted == y).sum().item()
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        _, predicted = torch.max(y_hat.data, 1)
        # logging metrics we calculated by hand
        self.log('validation/loss', loss, on_epoch=True, logger=True)
        # logging a pl.Metric
        self.valid_acc(y_hat, y)
        self.log('validation/acc', self.valid_acc,prog_bar=True, on_epoch=True)
        return y.size(0),(predicted == y).sum().item()


    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)


In [84]:
from torchsummary import summary
summary(model_full_train, next(iter(mnist.val_dataloader()))[0][1].shape,device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 14, 14]           9,408
       BatchNorm2d-2           [-1, 64, 14, 14]             128
              ReLU-3           [-1, 64, 14, 14]               0
         MaxPool2d-4             [-1, 64, 7, 7]               0
            Conv2d-5             [-1, 64, 7, 7]          36,864
       BatchNorm2d-6             [-1, 64, 7, 7]             128
              ReLU-7             [-1, 64, 7, 7]               0
            Conv2d-8             [-1, 64, 7, 7]          36,864
       BatchNorm2d-9             [-1, 64, 7, 7]             128
             ReLU-10             [-1, 64, 7, 7]               0
       BasicBlock-11             [-1, 64, 7, 7]               0
           Conv2d-12             [-1, 64, 7, 7]          36,864
      BatchNorm2d-13             [-1, 64, 7, 7]             128
             ReLU-14             [-1, 6

In [93]:
run = wandb.init(project="GTN", entity="skylab", reinit=True)

In [94]:
wandb_logger = WandbLogger()
trainer = pl.Trainer(accelerator="gpu",
                     max_epochs=5,
                    logger=wandb_logger,    # W&B integration
                    log_every_n_steps=10,   # set the logging frequency
                    )
trainer.fit(LigModel(model_full_train), datamodule=cifar10)

  "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse"
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type     | Params
---------------------------------------
0 | model     | ResNet   | 11.2 M
1 | train_acc | Accuracy | 0     
2 | valid_acc | Accuracy | 0     
3 | test_acc  | Accuracy | 0     
---------------------------------------
5.1 K     Trainable params
11.2 M    Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [95]:
run.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▁▁▁▁▃▃▃▃▃▃▃▃▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆████████
train/acc_epoch,▁▇▇██
train/acc_step,▁▄▄▆▄▆▅▃▆▄▅▅▅▄▆▅▆▄▄▆▄▅▅▅▃▆▆▄▆▇▃▆▃▆▆█▃▆▃▇
train/loss_epoch,█▂▁▁▁
train/loss_step,█▅▄▄▅▃▄▄▄▃▃▃▃▃▃▃▂▃▃▃▃▃▃▄▅▃▃▄▃▃▄▄▅▃▃▁▄▂▄▂
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
validation/acc,▁▇███
validation/loss,█▂▁▁▁

0,1
epoch,5.0
train/acc_epoch,0.44537
train/acc_step,0.4375
train/loss_epoch,1.58481
train/loss_step,1.59318
trainer/global_step,6249.0
validation/acc,0.447
validation/loss,1.61734


In [None]:
wandb.finish()