In [78]:
import torch

import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from torch.utils.data import random_split, DataLoader, Dataset
from pytorch_lightning.loggers import TensorBoardLogger


import utils.resnet as resnet 

In [79]:
hyperparameters = { 'batch size': 64,
                    'val ratio': 0.1,
                    'epochs': 100,
                    'lr': 0.0000095,
                    'lr decay': 0.25,
                    'lr decay threshold': 0.05,
                    'lr warming up period': 50,
                    'weight decay': 0.01,
                    'dataset path': './dataset'}

In [80]:
class dataset(Dataset):
    def __init__(self, dataset_path):
        self.images = torch.load(f'{dataset_path}/dataset_images')
        self.dnas = torch.load(f'{dataset_path}/dataset_dna')
        self.n_samples = len(self.images) 

    def __getitem__(self, index):
        image, i = self.images[index]
        return image, self.dnas[i]

    def __len__(self):
        return self.n_samples

In [81]:
class Net(pl.LightningModule):

    def __init__(self, model):
        super(Net, self).__init__()
        self.model = model

        self.learning_rate = hyperparameters['lr']
        self.weight_decay = hyperparameters['weight decay']



    def mse(self, x, y):
        return  F.mse_loss(x, y)


    def training_step(self, batch, batch_idx):
        images, dnas = batch
        

        predictions = self.model(images)

        loss = self.mse(predictions, dnas)       
        return {'loss': loss}

    def trainning_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        self.log('train_loss', avg_loss, prog_bar=True, on_epoch=True)
        return {'loss': avg_loss}

    def val_step(self, batch, batch_idx):
        images, dnas = batch
        

        predictions = self.model(images)

        loss = self.mse(predictions, dnas)       
        return {'val_loss': loss}

    def val_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        self.log('val_loss', avg_loss, prog_bar=True, on_epoch=True)
        return {'val_loss': avg_loss}

   
    def configure_optimizers(self):
        optimizer = torch.optim.RMSprop(self.model.parameters(), lr = self.learning_rate, weight_decay = self.weight_decay)
        return optimizer       

                                                   

In [82]:
class Dataset(pl.LightningDataModule):
    def __init__(self, batch_size, val_ratio, dataset_path):
        self.val_ratio = val_ratio
        self.batch_size = batch_size
        self.dataset_path = dataset_path

    def setup(self, stage=None):
        data = dataset(self.dataset_path)
        valid_idx = int(len(data)*self.val_ratio)
        train_idx = len(data)-valid_idx

        self.train_data, self.val_data = random_split(data, [train_idx, valid_idx])
        

    def train_dataloader(self):
        return DataLoader(self.train_data, batch_size=self.batch_size)
    
    def val_dataloader(self):
        return DataLoader(self.val_data, batch_size=self.batch_size)


In [76]:
data = Dataset(hyperparameters['batch size'], hyperparameters['val ratio'], hyperparameters['dataset path'])


model = Net(resnet.ResNet50(img_channel=3, num_features=101))
trainer = Trainer(gpus=1, max_epochs=hyperparameters['epochs'])

  torch.nn.init.xavier_uniform(layer)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [77]:
trainer.fit(model, data)

  rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params
---------------------------------
0 | model | ResNet | 23.7 M
---------------------------------
23.7 M    Trainable params
0         Non-trainable params
23.7 M    Total params
94.860    Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(


Epoch 34: 100%|██████████| 1/1 [00:03<00:00,  3.96s/it, loss=21.7, v_num=19]    
