In [1]:
# Optional: PyTorch/XLA (TPU) environment setup, nightly build.
# Left commented out on purpose; uncomment only when running on a TPU runtime.
#!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
#!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev

In [None]:
!pip install --upgrade pip
!pip install --upgrade albumentations
!pip install neptune-client
!pip install pytorch-lightning
!pip install pytorch_ranger

!pip install --upgrade pip
!pip install --upgrade --force-reinstall --no-deps kaggle
!pip install timm
!pip install yacs

In [None]:
!mkdir /root/.kaggle/
!cp -avr /content/drive/MyDrive/kaggle.json /root/.kaggle/

!kaggle competitions download -c cassava-leaf-disease-classification
!git clone https://github.com/GenoM87/cassava_leaf.git
!mkdir cassava_leaf/data
!mkdir cassava_leaf/experiments
!unzip cassava-leaf-disease-classification.zip -d cassava_leaf/data

!python cassava_leaf/src/create_folds.py

In [2]:
import os

# Switch to the repository's src/ directory; the next cell imports project
# modules (config, models, data_builder) that presumably live there.
SRC_DIR = os.path.join(os.pardir, 'src')
os.chdir(SRC_DIR)

In [8]:
import numpy as np
import pandas as pd

import sys, os, time, logging, datetime, random
from pathlib import Path

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

from config import _C as cfg
from models.create_model import CustomNet

from data_builder import build_valid_loader, build_train_loader
from models.optimizer import make_optimizer
from models.scheduler import make_scheduler

# TODO: try using this loss (bi-tempered logistic loss)
from models.loss import BiTemperedLogisticLoss

In [4]:
# Create the experiment output directory:
#   <PROJECT_DIR>/experiments/<MODEL.NAME>/<YYYY-MM-DD>
run_date = datetime.date.today().isoformat()
path_exp = os.path.join(cfg.PROJECT_DIR, 'experiments', cfg.MODEL.NAME, run_date)

# Creates missing parents; no error if the directory already exists.
os.makedirs(path_exp, exist_ok=True)

In [9]:
def set_seed(seed: int = 2004) -> None:
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: value used for every random number generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    # `deterministic` alone is not enough: benchmark mode may still pick
    # non-deterministic algorithms, so disable it explicitly.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
# Seed all RNGs from the project config before the model/data cells below run.
set_seed(cfg.RANDOM_STATE)

In [None]:
class hmapModel(pl.LightningModule):
    """LightningModule wrapping ``CustomNet`` for cassava leaf classification.

    Loss is ``BiTemperedLogisticLoss`` configured from ``cfg.SOLVER``; the
    tracked metric is accuracy on argmax predictions.

    NOTE(review): the class name and the dataloader methods are leftovers
    from a HuBMAP notebook (see the notes on ``train_dataloader``).
    """

    def __init__(self):
        super().__init__()
        # Network built entirely from the yacs config.
        self.model = CustomNet(cfg)
        # Separate metric objects so train and validation state never mix.
        # NOTE: `pl.metrics` requires pytorch-lightning < 1.5 (later: torchmetrics).
        self.train_accuracy = pl.metrics.Accuracy()
        self.valid_accuracy = pl.metrics.Accuracy()
        self.loss_fn = BiTemperedLogisticLoss(
            t1=cfg.SOLVER.BIT_T1,
            t2=cfg.SOLVER.BIT_T2,
            smoothing=cfg.SOLVER.SMOOTHING_LOSS,
        )

    def forward(self, x):
        """Return the raw model output (logits) for a batch of inputs."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the loss on one training batch and track accuracy.

        Args:
            batch: ``(inputs, labels)`` pair; labels presumably integer class
                ids — TODO confirm against the data loader.
            batch_idx: index of the batch (unused).

        Returns:
            dict whose ``'loss'`` entry Lightning backpropagates.
        """
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        # Use the module's own accuracy metric; `dice_fn` was never defined.
        train_acc = self.train_accuracy(logits.argmax(dim=1), y)
        self.log('train_loss', loss)
        self.log('train_acc', train_acc, on_step=False, on_epoch=True, logger=True)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        """Compute validation loss/accuracy for one batch.

        ``val_loss`` is logged because ``configure_optimizers`` monitors it.
        """
        x, y = batch
        logits = self(x)
        # Same loss call as training_step: the instance loss (the bare name
        # `loss_fn` is undefined) and labels passed unchanged, with no
        # `.squeeze(1)` / `.float()` cast.
        val_loss = self.loss_fn(logits, y)
        val_acc = self.valid_accuracy(logits.argmax(dim=1), y)
        self.log('val_loss', val_loss)
        self.log('valid_acc', val_acc, on_step=False, on_epoch=True, logger=True)
        return {'val_loss': val_loss, 'valid_acc': val_acc}

    def configure_optimizers(self):
        """Adam plus ReduceLROnPlateau stepped on the logged ``val_loss``."""
        # NOTE(review): `CONFIG` is not defined in any cell above; the yacs
        # `cfg` (or the imported `make_optimizer`/`make_scheduler`) is
        # presumably the intended source of the learning rate — TODO wire in.
        optimizer = torch.optim.Adam(self.parameters(), lr=CONFIG['lr'])
        # `optim` was never imported; use the fully-qualified torch path.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=4)
        return {
            'optimizer': optimizer,
            'lr_scheduler': scheduler,
            'monitor': 'val_loss',
        }

    def train_dataloader(self):
        """Build the shuffled training DataLoader.

        NOTE(review): `HuBMAPDataset`, `DataLoader`, `train_ids`,
        `get_train_aug` and `CONFIG` are not defined in any cell above. The
        notebook imports `build_train_loader` from `data_builder`, which is
        presumably the intended replacement — TODO confirm its signature.
        """
        dataset = HuBMAPDataset(
            train_ids,
            tfms=get_train_aug(CONFIG['size']),
        )
        return DataLoader(
            dataset,
            batch_size=CONFIG['train_batch_size'],
            shuffle=True,
            num_workers=CONFIG['num_workers'],
        )

    def val_dataloader(self):
        """Build the (unshuffled) validation DataLoader.

        NOTE(review): same undefined names as ``train_dataloader``;
        `build_valid_loader` (imported above) is presumably the intended
        replacement — TODO confirm its signature.
        """
        dataset = HuBMAPDataset(
            val_ids,
            tfms=get_valid_aug(CONFIG['size']),
        )
        return DataLoader(
            dataset,
            batch_size=CONFIG['val_batch_size'],
            shuffle=False,
            num_workers=CONFIG['num_workers'],
        )