In [1]:
%%capture
#Installing the requirements that Google Colab doesn't have
!pip install timm 
!pip install wandb --quiet
!pip install pytorch-lightning

#Unzipping it into the current folder
!unzip -qq MOD10A1_sierras.zip -d .
!unzip -qq MYD10A1_sierras.zip -d .
!unzip -qq copernicus_sierras2.zip -d .

In [2]:
#All of our imports
import gc
import os

import numpy as np
import pandas as pd

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.nn import functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR

import torchvision
from torchvision import transforms as T
from torchvision.io import read_image

from torchmetrics import R2Score

import timm

from tqdm import tqdm_notebook as tqdm

import pytorch_lightning as pl
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import progress
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.callbacks import LearningRateMonitor

import sklearn
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import MinMaxScaler

import wandb

#Every timm architecture that ships pretrained weights -- the pool to pick a backbone from.
pretrained_model_names = timm.list_models(pretrained=True)
pretrained_model_names

In [3]:
class args:
  """Notebook-wide configuration, read as plain class attributes (e.g. ``args.lr``)."""

  # ---- paths, workers, reproducibility ----
  folder_name: str = "drive/MyDrive/snowcapstone team spring 2022/Modeling/"
  num_workers: int = 4          # CPU worker processes per DataLoader
  seed: int = 1212              # global seed handed to seed_everything

  # ---- backbones ----
  pretrained: bool = True       # start from ImageNet weights
  # timm model names paired with the width of their pooled feature vector
  model_name1: str = 'tf_efficientnet_b3_ns'
  model_shape1: int = 1536
  model_name2: str = 'tf_efficientnet_b4_ns'
  # other widths seen: 768 swin small, 1536 swin large, 1792 efficientnet b4, 768 cait-m-36
  model_shape2: int = 1792
  imagesize: int = 224          # images are resized to imagesize x imagesize
  num_classes: int = 1          # a single regression output

  # ---- batching ----
  train_batch_size: int = 32
  val_batch_size: int = 32
  test_batch_size: int = 32

  # ---- training length and cross-validation ----
  max_epochs: int = 80
  n_splits: int = 2

  # ---- loss, optimizer, scheduler ----
  loss: str = 'nn.BCEWithLogitsLoss'   # expression evaluated into the criterion class
  lr: float = 3e-4
  warmup_epochs: int = 5
  weight_decay: float = 3e-6
  eta_min: float = 0.000001
  n_accumulate: int = 1
  T_0: int = 25
  T_max: int = 1000

  # ---- early stopping ----
  min_delta: float = 0.0        # smallest change that counts as an improvement
  patience: int = 25            # epochs without improvement before stopping

#Keyword arguments shared by every DataLoader the datamodule builds
loaderargs = dict(
    num_workers=args.num_workers,
    pin_memory=False,
    drop_last=False,
)
device = torch.device("cuda", 0)

seed_everything(args.seed)


Global seed set to 1212


1212

In [4]:
#Reading in the data
df = pd.read_csv('traindf_allregion_4images.csv')

#Metadata features are every column except identifiers, image paths and the target
non_feature_cols = ['cell_id', 'date', 'MOD10A1_filelocations', 'MYD10A1_filelocations',
                    'copernicus_filelocations', 'SWE', 'sentinel1a_filelocation', 'SWE_Scaled']
feature_cols = [col for col in df.columns if col not in non_feature_cols]

#Min max scaling the meta data.
#NOTE(review): both scalers are fit on the full frame before the K-fold split, so the
#validation rows contribute to the scaling statistics.
scaler = MinMaxScaler()
df[feature_cols] = scaler.fit_transform(df[feature_cols])

#Separate scaler for the target so predictions can be mapped back to SWE units.
#It is fit on a plain ndarray (no feature names) because the model later calls
#`inverse_transform` on ndarrays; fitting on a DataFrame would make sklearn warn on every call.
target_scaler = MinMaxScaler()
df['SWE_Scaled'] = target_scaler.fit_transform(df[['SWE']].to_numpy())[:, 0]

#Number of metadata features (name kept as-is: the model cell reads `tabluar_columns`)
tabluar_columns = len(feature_cols)

In [5]:
#Point the file-location columns at the local folders the archives were unzipped into,
#instead of the absolute Google Drive paths stored in the CSV.
#`regex=False` makes each prefix an exact literal match rather than a regular expression.
drive_root = '/content/drive/MyDrive/snowcapstone team spring 2022/'
path_prefix_fixes = {
    'MOD10A1_filelocations': (drive_root + 'MODIS_Data/MOD10A1/', 'MOD10A1/'),
    'MYD10A1_filelocations': (drive_root + 'MODIS_Data/MYD10A1/', 'MYD10A1/'),
    'copernicus_filelocations': (drive_root + 'Copernicus_Data/', 'Copernicus_Data/'),
    'sentinel1a_filelocation': (drive_root + 'Sen1_Data/', 'Sen1_Data/'),
}
for path_col, (old_prefix, new_prefix) in path_prefix_fixes.items():
    df[path_col] = df[path_col].str.replace(old_prefix, new_prefix, regex=False)

In [6]:
#Peek at the first rows to check the rewritten file-location columns
df.head()

Unnamed: 0,cell_id,date,SWE,mean_inversed_swe,mean_local_swe,median_local_swe,max_local_swe,min_local_swe,mean_local_elevation,median_local_elevation,max_local_elevation,min_local_elevation,MOD10A1_filelocations,MYD10A1_filelocations,sentinel1a_filelocation,copernicus_filelocations,SWE_Scaled
0,00c4db22-a423-41a4-ada6-a8b1b04153a4,2016-01-05,10.6,0.126408,0.1588,0.133919,0.216199,0.144804,0.426831,0.435383,0.567589,0.407159,MOD10A1/00c4db22-a423-41a4-ada6-a8b1b04153a4_M...,MYD10A1/00c4db22-a423-41a4-ada6-a8b1b04153a4_M...,Sen1_Data/00c4db22-a423-41a4-ada6-a8b1b04153a4...,Copernicus_Data/00c4db22-a423-41a4-ada6-a8b1b0...,0.086957
1,018cf1a1-f945-4097-9c47-0c4690538bb5,2016-01-05,16.4,0.165288,0.112122,0.12822,0.184924,0.0,0.400201,0.422836,0.417655,0.360179,MOD10A1/018cf1a1-f945-4097-9c47-0c4690538bb5_M...,MYD10A1/018cf1a1-f945-4097-9c47-0c4690538bb5_M...,Sen1_Data/018cf1a1-f945-4097-9c47-0c4690538bb5...,Copernicus_Data/018cf1a1-f945-4097-9c47-0c4690...,0.134537
2,01be2cc7-ef77-4e4d-80ed-c4f8139162c3,2016-01-05,21.1,0.175357,0.122004,0.110412,0.173216,0.066723,0.679839,0.695107,0.73386,0.574944,MOD10A1/01be2cc7-ef77-4e4d-80ed-c4f8139162c3_M...,MYD10A1/01be2cc7-ef77-4e4d-80ed-c4f8139162c3_M...,Sen1_Data/01be2cc7-ef77-4e4d-80ed-c4f8139162c3...,Copernicus_Data/01be2cc7-ef77-4e4d-80ed-c4f813...,0.173093
3,02c3ec4a-8de4-4284-9ec1-5a942d3d098e,2016-01-05,2.0,0.022714,0.023837,0.028493,0.024058,0.014196,0.721118,0.708908,0.768116,0.744966,MOD10A1/02c3ec4a-8de4-4284-9ec1-5a942d3d098e_M...,MYD10A1/02c3ec4a-8de4-4284-9ec1-5a942d3d098e_M...,Sen1_Data/02c3ec4a-8de4-4284-9ec1-5a942d3d098e...,Copernicus_Data/02c3ec4a-8de4-4284-9ec1-5a942d...,0.016407
4,02cf33c2-c8e2-48b9-bf72-92506e97e251,2016-01-05,9.2,0.085734,0.104234,0.12822,0.113873,0.070982,0.775233,0.766625,0.852437,0.762864,MOD10A1/02cf33c2-c8e2-48b9-bf72-92506e97e251_M...,MYD10A1/02cf33c2-c8e2-48b9-bf72-92506e97e251_M...,Sen1_Data/02cf33c2-c8e2-48b9-bf72-92506e97e251...,Copernicus_Data/02cf33c2-c8e2-48b9-bf72-92506e...,0.075472


In [7]:
#Some Sentinel-1 images are missing from disk. Rows whose file does not exist are
#redirected to a placeholder image (noted as an all-zero array) so every row stays loadable.
import os
placeholder_sen1 = 'Sen1_Data/b98777af-0c7c-44f7-9c03-85d6d412856c_sentinel1_VV_2019365.jpg'
sen1_on_disk = df['sentinel1a_filelocation'].map(os.path.exists)
df['sentinel1a_filelocation'] = df['sentinel1a_filelocation'].where(sen1_on_disk, placeholder_sen1)

In [8]:
#Some Sentinel-1 files exist but fail to decode (read_image raises RuntimeError).
#This assumes a readable copy sits next to each one with a "(1)" suffix before the
#extension, and points the affected rows at it.
#Each distinct path is decoded once: many rows share the same placeholder image, so
#looping over every row would decode it over and over. Iterating over `unique()`
#(a copy) also avoids mutating the column while it is being iterated.
for sen1_path in df['sentinel1a_filelocation'].unique():
    try:
        read_image(sen1_path, mode = torchvision.io.image.ImageReadMode.RGB)
    except RuntimeError:
        print(sen1_path,'triggered RuntimeError')
        #Insert "(1)" before the extension only; str.replace('.', ...) would hit every dot in the path
        stem, ext = os.path.splitext(sen1_path)
        fixed_path = f'{stem}(1){ext}'
        #Fail here rather than mid-training if the expected copy is not there
        if not os.path.exists(fixed_path):
            raise FileNotFoundError(f'{sen1_path} cannot be decoded and no copy exists at {fixed_path}')
        df.loc[df.sentinel1a_filelocation == sen1_path,'sentinel1a_filelocation'] = fixed_path
        print('replaced with: ',fixed_path)


Sen1_Data/5d62e7fe-926c-4f1d-8591-1f2885ddc4f8_sentinel1_VV_2018153.jpg triggered RuntimeError
replaced with:  Sen1_Data/5d62e7fe-926c-4f1d-8591-1f2885ddc4f8_sentinel1_VV_2018153(1).jpg
Sen1_Data/f8a873ef-2804-4f6b-babf-2d1fdbf7d4de_sentinel1_VV_2018153.jpg triggered RuntimeError
replaced with:  Sen1_Data/f8a873ef-2804-4f6b-babf-2d1fdbf7d4de_sentinel1_VV_2018153(1).jpg


In [9]:
#Confirm a repaired row now points at its "(1)" copy
repaired_sen1_path = 'Sen1_Data/5d62e7fe-926c-4f1d-8591-1f2885ddc4f8_sentinel1_VV_2018153(1).jpg'
df[df['sentinel1a_filelocation'].eq(repaired_sen1_path)]

Unnamed: 0,cell_id,date,SWE,mean_inversed_swe,mean_local_swe,median_local_swe,max_local_swe,min_local_swe,mean_local_elevation,median_local_elevation,max_local_elevation,min_local_elevation,MOD10A1_filelocations,MYD10A1_filelocations,sentinel1a_filelocation,copernicus_filelocations,SWE_Scaled
38608,5d62e7fe-926c-4f1d-8591-1f2885ddc4f8,2018-06-02,0.0,0.000208,0.001114,0.000407,0.001925,0.0,0.51699,0.491844,0.54282,0.541387,MOD10A1/5d62e7fe-926c-4f1d-8591-1f2885ddc4f8_M...,MYD10A1/5d62e7fe-926c-4f1d-8591-1f2885ddc4f8_M...,Sen1_Data/5d62e7fe-926c-4f1d-8591-1f2885ddc4f8...,Copernicus_Data/5d62e7fe-926c-4f1d-8591-1f2885...,0.0


In [10]:
#Check that every distinct Sentinel-1 path now decodes; print and then fail on any that still don't,
#so a bad file is caught here instead of inside a DataLoader worker mid-training.
unreadable_sen1 = []
for sen1_path in df['sentinel1a_filelocation'].unique():
    try:
        read_image(sen1_path, mode = torchvision.io.image.ImageReadMode.RGB)
    except RuntimeError:
        print(sen1_path)
        unreadable_sen1.append(sen1_path)
assert not unreadable_sen1, f'{len(unreadable_sen1)} Sentinel-1 image(s) still fail to decode'

In [11]:
#Peek again after the Sentinel-1 fixes: rows with missing files now point at the placeholder image
df.head()

Unnamed: 0,cell_id,date,SWE,mean_inversed_swe,mean_local_swe,median_local_swe,max_local_swe,min_local_swe,mean_local_elevation,median_local_elevation,max_local_elevation,min_local_elevation,MOD10A1_filelocations,MYD10A1_filelocations,sentinel1a_filelocation,copernicus_filelocations,SWE_Scaled
0,00c4db22-a423-41a4-ada6-a8b1b04153a4,2016-01-05,10.6,0.126408,0.1588,0.133919,0.216199,0.144804,0.426831,0.435383,0.567589,0.407159,MOD10A1/00c4db22-a423-41a4-ada6-a8b1b04153a4_M...,MYD10A1/00c4db22-a423-41a4-ada6-a8b1b04153a4_M...,Sen1_Data/b98777af-0c7c-44f7-9c03-85d6d412856c...,Copernicus_Data/00c4db22-a423-41a4-ada6-a8b1b0...,0.086957
1,018cf1a1-f945-4097-9c47-0c4690538bb5,2016-01-05,16.4,0.165288,0.112122,0.12822,0.184924,0.0,0.400201,0.422836,0.417655,0.360179,MOD10A1/018cf1a1-f945-4097-9c47-0c4690538bb5_M...,MYD10A1/018cf1a1-f945-4097-9c47-0c4690538bb5_M...,Sen1_Data/b98777af-0c7c-44f7-9c03-85d6d412856c...,Copernicus_Data/018cf1a1-f945-4097-9c47-0c4690...,0.134537
2,01be2cc7-ef77-4e4d-80ed-c4f8139162c3,2016-01-05,21.1,0.175357,0.122004,0.110412,0.173216,0.066723,0.679839,0.695107,0.73386,0.574944,MOD10A1/01be2cc7-ef77-4e4d-80ed-c4f8139162c3_M...,MYD10A1/01be2cc7-ef77-4e4d-80ed-c4f8139162c3_M...,Sen1_Data/b98777af-0c7c-44f7-9c03-85d6d412856c...,Copernicus_Data/01be2cc7-ef77-4e4d-80ed-c4f813...,0.173093
3,02c3ec4a-8de4-4284-9ec1-5a942d3d098e,2016-01-05,2.0,0.022714,0.023837,0.028493,0.024058,0.014196,0.721118,0.708908,0.768116,0.744966,MOD10A1/02c3ec4a-8de4-4284-9ec1-5a942d3d098e_M...,MYD10A1/02c3ec4a-8de4-4284-9ec1-5a942d3d098e_M...,Sen1_Data/b98777af-0c7c-44f7-9c03-85d6d412856c...,Copernicus_Data/02c3ec4a-8de4-4284-9ec1-5a942d...,0.016407
4,02cf33c2-c8e2-48b9-bf72-92506e97e251,2016-01-05,9.2,0.085734,0.104234,0.12822,0.113873,0.070982,0.775233,0.766625,0.852437,0.762864,MOD10A1/02cf33c2-c8e2-48b9-bf72-92506e97e251_M...,MYD10A1/02cf33c2-c8e2-48b9-bf72-92506e97e251_M...,Sen1_Data/b98777af-0c7c-44f7-9c03-85d6d412856c...,Copernicus_Data/02cf33c2-c8e2-48b9-bf72-92506e...,0.075472


In [12]:
#Making sure we don't have any NAs.
#Only the scaled metadata features are zero-filled. A missing target, id or image path
#cannot be meaningfully imputed with 0 (it would silently become a 0-SWE label or an
#invalid path), so any NA left in those columns fails the assertion instead.
print(feature_cols)
print(df.isna().sum())
df[feature_cols] = df[feature_cols].fillna(0)
print(df.isna().sum())
assert not df.isna().any().any(), 'NAs remain outside the metadata feature columns'

['mean_inversed_swe', 'mean_local_swe', 'median_local_swe', 'max_local_swe', 'min_local_swe', 'mean_local_elevation', 'median_local_elevation', 'max_local_elevation', 'min_local_elevation']
cell_id                     0
date                        0
SWE                         0
mean_inversed_swe           0
mean_local_swe              0
median_local_swe            0
max_local_swe               0
min_local_swe               0
mean_local_elevation        0
median_local_elevation      0
max_local_elevation         0
min_local_elevation         0
MOD10A1_filelocations       0
MYD10A1_filelocations       0
sentinel1a_filelocation     0
copernicus_filelocations    0
SWE_Scaled                  0
dtype: int64
cell_id                     0
date                        0
SWE                         0
mean_inversed_swe           0
mean_local_swe              0
median_local_swe            0
max_local_swe               0
min_local_swe               0
mean_local_elevation        0
median_local_elev

In [13]:
#Check column dtypes: metadata and target columns should be numeric; id, date and path columns are object
df.dtypes

cell_id                      object
date                         object
SWE                         float64
mean_inversed_swe           float64
mean_local_swe              float64
median_local_swe            float64
max_local_swe               float64
min_local_swe               float64
mean_local_elevation        float64
median_local_elevation      float64
max_local_elevation         float64
min_local_elevation         float64
MOD10A1_filelocations        object
MYD10A1_filelocations        object
sentinel1a_filelocation      object
copernicus_filelocations     object
SWE_Scaled                  float64
dtype: object

In [14]:
#Datasets are how pytorch knows how to read in the data
class SWEDataset(torch.utils.data.Dataset):
    """One sample per row of `df`.

    `__getitem__` returns
    ``(MOD10A1_img, MYD10A1_img, copernicus_img, sentinel1_img, target, meta)``.
    Each image is decoded as RGB, divided by 255 and resized to
    ``args.imagesize x args.imagesize``. `meta` is the row's `feature_cols` values.
    `target` is the row's `SWE_Scaled` value, or 0 when `test` is True.
    """

    def __init__(self, df, test = False):
        self.df = df
        self.test = test
        #Image path arrays, one per source, in the order they are returned
        self.MOD10A1_file_names = df['MOD10A1_filelocations'].values
        self.MYD10A1_file_names = df['MYD10A1_filelocations'].values
        self.copernicus_file_names = df['copernicus_filelocations'].values
        self.sentinel1_file_names = df['sentinel1a_filelocation'].values
        #Resizing is the only per-sample transform; the rest happens batch-wise in the model
        self._transform = T.Resize(size= (args.imagesize, args.imagesize))
        self.meta = df[feature_cols].values
        self.targets = df['SWE_Scaled'].values

    def __len__(self):
        return len(self.df)

    def _load(self, path):
        """Decode `path` as RGB, scale its pixels from 0-255 to 0-1 and resize it."""
        pixels = read_image(path, mode = torchvision.io.image.ImageReadMode.RGB) / 255
        return self._transform(pixels)

    def __getitem__(self, index):
        sources = (self.MOD10A1_file_names, self.MYD10A1_file_names,
                   self.copernicus_file_names, self.sentinel1_file_names)
        images = [self._load(paths[index]) for paths in sources]
        target = 0 if self.test else self.targets[index]
        return (*images, target, self.meta[index, :])

In [15]:
#Pytorch Lightning Requires that the dataset be formatted as a module
class SWEDataModule(pl.LightningDataModule):
    """Wrap one fold's train/validation frames as Lightning dataloaders.

    Args:
        traindf: rows used for training (wrapped in `SWEDataset`).
        valdf: rows used for validation (wrapped in `SWEDataset`).
        args: config object providing `train_batch_size` and `val_batch_size`.
        loaderargs: extra keyword arguments forwarded to every `DataLoader`.
    """
    def __init__(self, traindf, valdf,args, loaderargs):
        super().__init__()
        self._train_df = traindf
        self._val_df = valdf
        self.args = args
        self.loaderargs = loaderargs

    #Building the datasets
    def __create_dataset(self, train="train"):
        """Return the training dataset when `train` is "train" (or True), else the validation one.

        The old default (`True`) fell through to the validation set; it now selects training.
        """
        if train == 'train' or train is True:
            return SWEDataset(self._train_df)
        return SWEDataset(self._val_df)

    #Using the datasets to return a dataloader
    def train_dataloader(self):
        SWE_train = self.__create_dataset("train")
        #Training batches must be shuffled: the fold indices preserve the frame's row order,
        #which (judging by df.head()) is grouped by date, so unshuffled batches would each
        #cover one narrow slice of the season. A `shuffle` key in loaderargs still wins.
        train_loader_kwargs = {'shuffle': True, **self.loaderargs}
        return DataLoader(SWE_train, **train_loader_kwargs, batch_size=self.args.train_batch_size)

    def val_dataloader(self):
        SWE_val = self.__create_dataset("val")
        return DataLoader(SWE_val, **self.loaderargs, batch_size=self.args.val_batch_size)
    

In [16]:
def get_default_transforms():
    """Return the batch-level image transforms keyed by split ("train" / "val").

    Both splits convert to float and normalize with the ImageNet channel mean/std.
    "train" first applies a brightness/contrast/saturation jitter. Geometric
    augmentations (flips, affine) are left disabled.
    """
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)

    train_steps = [
        T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
        T.ConvertImageDtype(torch.float),
        T.Normalize(mean = imagenet_mean, std = imagenet_std),
    ]
    val_steps = [
        T.ConvertImageDtype(torch.float),
        T.Normalize(mean = imagenet_mean, std = imagenet_std),
    ]
    return {"train": T.Compose(train_steps), "val": T.Compose(val_steps)}
  

def mixup(x1: torch.Tensor, x2: torch.Tensor, x3: torch.Tensor, x4:torch.Tensor, y: torch.Tensor, 
          z: torch.Tensor = None, alpha: float = 1.0):
    """Blend a batch with a shuffled copy of itself (mixup-style augmentation).

    One weight ``lam ~ Beta(alpha, alpha)`` and one batch permutation are drawn and
    applied identically to the four image batches and the metadata, so each mixed
    sample stays consistent across all inputs.

    Args:
        x1, x2, x3, x4: image batches sharing the same leading batch dimension.
        y: targets, one per sample (returned unmixed as target_a / permuted as target_b).
        z: metadata batch with the same leading dimension. Optional; when omitted the
            returned mixed metadata is None. (Previously `z = torch.Tensor` made the
            tensor *class* the default value instead of annotating the parameter.)
        alpha: Beta concentration; must be > 0.

    Returns:
        (mixed_x1, mixed_x2, mixed_x3, mixed_x4, mixed_meta, target_a, target_b, lam)
        where the mixed loss is ``lam * L(target_a) + (1 - lam) * L(target_b)`` and
        `lam` is a Python float.

    Raises:
        AssertionError: alpha <= 0, a batch of one sample, or mismatched batch sizes.
    """
    assert alpha > 0, "alpha should be larger than 0"
    batch_size = x1.size(0)
    assert batch_size > 1, "Mixup cannot be applied to a single instance."
    assert all(t.size(0) == batch_size for t in (x2, x3, x4, y)), \
        "All inputs must share the same batch size."
    assert z is None or z.size(0) == batch_size, "Metadata batch size does not match the images."

    #Plain float so `lam * tensor` never goes through numpy-scalar arithmetic
    lam = float(np.random.beta(alpha, alpha))
    rand_index = torch.randperm(batch_size)

    def _blend(t):
        #Indexing the first axis only, so 1-D inputs work too
        return lam * t + (1 - lam) * t[rand_index]

    mixed_x1, mixed_x2, mixed_x3, mixed_x4 = (_blend(t) for t in (x1, x2, x3, x4))
    mixed_meta = None if z is None else _blend(z)
    target_a, target_b = y, y[rand_index]
    return mixed_x1, mixed_x2, mixed_x3, mixed_x4, mixed_meta, target_a, target_b,  lam

In [17]:
#Use this to find the model_shape attribute when changing models:
#the printed width is the backbone's feature dimension (what args.model_shape2 holds for model_name2).
#Only the shape matters, so the probe runs in eval mode without autograd and is freed
#afterwards instead of leaving a stray `model` global behind. `pretrained` is omitted
#on purpose: weights don't change the output shape, so there is nothing to download.
x = torch.randn(1,3,args.imagesize,args.imagesize)
shape_probe = timm.create_model(args.model_name2,
                                num_classes=0,
                                in_chans = 3).eval()
with torch.no_grad():
    probe_shape = shape_probe(x).shape
del shape_probe
probe_shape

torch.Size([1, 1792])

In [18]:
class SWEModel(LightningModule):
    """Regress SWE from four stacked satellite images plus tabular metadata.

    The four RGB batches (MOD10A1, MYD10A1, Copernicus, Sentinel-1) are concatenated
    along the channel axis into a 12-channel input for one timm backbone. The backbone
    is built with ``global_pool=''`` so it returns a spatial feature map. Two
    conv/BatchNorm blocks shrink that map. The result is flattened, concatenated with
    the metadata and fed to an MLP that emits one logit per sample.

    The loss (``args.loss``) compares the logit with the min-max-scaled target.
    Logged predictions are ``sigmoid(logit)`` mapped back to SWE units with the
    target scaler.

    Configuration comes from the notebook-level ``args``, ``target_scaler`` and
    ``tabluar_columns`` globals.
    """
    def __init__(self):
        super().__init__()
        self.args = args
        self.scaler = target_scaler
        self.tabular_columns = tabluar_columns
        #4 input images x 3 RGB channels, stacked in forward()
        in_chans = 12
        #Image Models
        self.model1 = timm.create_model(args.model_name2, 
                                       pretrained=args.pretrained, 
                                       num_classes=0,
                                       in_chans = in_chans,
                                       global_pool='')

        #fc2's input size used to be the literal 585 (= 64 * 3 * 3 + 9: a 7x7 backbone map at
        #224px halved by mp2, plus 9 metadata columns). That breaks as soon as the image size,
        #backbone or metadata columns change, so derive it from the real backbone output.
        fmap_channels, fmap_height, fmap_width = self._backbone_feature_shape(in_chans)
        assert fmap_channels == args.model_shape2, (
            f'args.model_shape2={args.model_shape2} but {args.model_name2} '
            f'outputs {fmap_channels} channels')
        #cv1/cv2 (3x3, stride 1, padding 1) and mp1 (1x1, stride 1) keep the spatial size;
        #mp2 (2x2, stride 2) floors it to half. cv2 outputs 64 channels.
        head_in_features = 64 * (fmap_height // 2) * (fmap_width // 2) + self.tabular_columns

        #self.fc1 = nn.Linear(args.model_shape1*2 + args.model_shape2, 768)
        self.fc2 = nn.Linear(head_in_features, 384)
        self.fc3 = nn.Linear(384, 96)
        self.fc4 = nn.Linear(96, args.num_classes)
        self.dropout = nn.Dropout(p=0.3)
        self.relu = nn.ReLU()
        #NOTE(review): `args.loss` is evaluated as Python code. It is a notebook-local
        #setting; never populate it from untrusted input.
        self._criterion = eval(self.args.loss)()
        self.transform = get_default_transforms()

        self.cv1 = nn.Conv2d(args.model_shape2, 256 ,kernel_size=3,stride=1,padding=1)
        self.bn1 = nn.BatchNorm2d(256)
        self.cv2 = nn.Conv2d(256, 64 ,kernel_size=3,stride=1,padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        #1x1 / stride-1 pooling: effectively an identity
        self.mp1 = nn.MaxPool2d(kernel_size=1,stride=1)
        self.mp2 = nn.MaxPool2d(kernel_size=2,stride=2)

        self.trainr2 = R2Score()
        self.valr2 = R2Score()

    def _backbone_feature_shape(self, in_chans):
        """Return (channels, height, width) of the backbone feature map for one dummy image.

        Runs in eval mode under no_grad so BatchNorm running statistics are untouched
        and no autograd graph is built; the module's previous train/eval mode is restored.
        """
        was_training = self.model1.training
        self.model1.eval()
        with torch.no_grad():
            fmap = self.model1(torch.zeros(1, in_chans, self.args.imagesize, self.args.imagesize))
        self.model1.train(was_training)
        return tuple(fmap.shape[1:])

    def forward(self, features1, features2, features3, features4, meta):
        """Return one logit per sample, shape (batch, args.num_classes).

        features1..features4 are the four image batches, concatenated on dim 1 (together
        they must supply the backbone's 12 input channels). `meta` is the 2-D
        (batch, n_features) metadata matrix appended after the flattened image features.
        """
        features = torch.cat((features1,features2,features3,features4),dim=1)

        #Image Model (spatial feature map, no global pooling)
        features = self.model1(features)
        features = self.relu(features)
        features = self.dropout(features)

        #Convolution blocks
        features = self.cv1(features)
        features = self.bn1(features)
        features = self.relu(features)
        features = self.dropout(features)
        features = self.mp1(features)
        features = self.cv2(features)
        features = self.bn2(features)
        features = self.relu(features)
        features = self.mp2(features)

        features = features.flatten(1)

        #Concatenating the meta data
        features = torch.cat([features, meta], dim=1)

        #Final fully connected layers
        features = self.fc2(features)
        features = self.relu(features)

        features = self.fc3(features)
        features = self.relu(features)

        output = self.fc4(features)
        return output

    def _to_swe_units(self, values):
        """Map scaled values back to SWE units with the target scaler (CPU tensor, shape (N, 1))."""
        return torch.from_numpy(
            self.scaler.inverse_transform(values.detach().cpu().numpy().reshape(-1, 1)))

    def __share_step(self, batch, mode):
        """Run one forward pass for `mode` ("train" or "val").

        Returns (loss, pred, labels); pred and labels are in SWE units as CPU tensors.
        """
        MOD10A1_img, MYD10A1_img, copernicus_img, sentinel1_img, labels, meta = batch
        labels = labels.float()
        meta = meta.float()
        MOD10A1_img = self.transform[mode](MOD10A1_img)
        MYD10A1_img = self.transform[mode](MYD10A1_img)
        copernicus_img = self.transform[mode](copernicus_img)
        sentinel1_img = self.transform[mode](sentinel1_img)

        rand_index = torch.rand(1)[0]

        #Mixup on roughly half the training batches. Mixing needs at least two samples,
        #so a trailing batch of one is trained on as-is instead of failing mixup's assertion.
        if rand_index < 0.5 and mode == 'train' and labels.size(0) > 1:
            MOD10A1_mixed, MYD10A1_mixed, copernicus_mixed, sentinel1_mixed, \
            mixed_meta, target_a, target_b, lam = mixup(MOD10A1_img, 
                                                          MYD10A1_img,
                                                          copernicus_img, sentinel1_img,
                                                          labels, meta, alpha=0.5)
            logits = self.forward(MOD10A1_mixed, MYD10A1_mixed, copernicus_mixed, sentinel1_mixed, mixed_meta).squeeze(1)
            loss = self._criterion(logits, target_a) * lam + \
                (1 - lam) * self._criterion(logits, target_b)
            #The mixed inputs correspond to the blended target, not the original labels,
            #so report train metrics against it. With the configured BCEWithLogitsLoss
            #(linear in its target) the loss above equals the loss against this target.
            labels = target_a * lam + target_b * (1 - lam)
        else:
            logits = self.forward(MOD10A1_img, MYD10A1_img, copernicus_img, sentinel1_img, meta).squeeze(1)
            loss = self._criterion(logits, labels)

        pred = self._to_swe_units(logits.sigmoid())
        labels = self._to_swe_units(labels)
        return loss, pred, labels

    def training_step(self, batch, batch_idx):
        loss, pred, labels = self.__share_step(batch, 'train')
        #pred/labels come back on the CPU; move them to wherever the model (and its metric submodules) lives
        self.trainr2(pred.to(self.device), labels.to(self.device))
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        return {'loss': loss, 'pred': pred, 'labels': labels}

    def validation_step(self, batch, batch_idx):
        loss, pred, labels = self.__share_step(batch, 'val')
        self.valr2(pred.to(self.device), labels.to(self.device))
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        return {'pred': pred, 'labels': labels}

    def training_epoch_end(self, outputs):
        self.log('train_r2_epoch',self.trainr2)
        self.__share_epoch_end(outputs, 'train')

    def validation_epoch_end(self, outputs):
        self.log('val_r2_epoch',self.valr2)
        self.__share_epoch_end(outputs, 'val')

    def __share_epoch_end(self, outputs, mode):
        """Log the epoch RMSE (in SWE units) over all step outputs as `{mode}_RMSE`."""
        preds = torch.cat([out['pred'] for out in outputs])
        labels = torch.cat([out['labels'] for out in outputs])
        metrics = torch.sqrt(((labels - preds) ** 2).mean())
        self.log(f'{mode}_RMSE', metrics)

    def configure_optimizers(self):
        """AdamW plus a per-step cosine schedule between args.lr and args.eta_min over args.T_max steps."""
        optimizer = AdamW(self.parameters(), lr=args.lr, weight_decay = args.weight_decay)
        #NOTE(review): CosineAnnealingLR is periodic; after T_max steps the LR climbs back toward args.lr.
        return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": CosineAnnealingLR(optimizer, T_max = args.T_max, eta_min= args.eta_min),
            "interval": "step",
            "monitor": "train_loss",
            "frequency": 1}
            }

In [None]:
#Stratify the folds on binned SWE so every fold covers a similar range of snow values
Kfolds = StratifiedKFold(n_splits=args.n_splits, shuffle=True, 
                         random_state = args.seed)

#Number of equal-width SWE bins used as stratification classes (Rice rule: 2 * n^(1/3))
num_bins = int(np.ceil(2*((len(df))**(1./3))))
df['bins'] = pd.cut(df['SWE'], bins=num_bins, labels=False)

weights_dir = '/home/ubuntu/snowcap/weights'

for fold, (train_idx, val_idx) in enumerate(Kfolds.split(df["cell_id"], df["bins"])):
    #split() yields positional indices, so select with iloc; loc only works while df
    #happens to keep a 0..n-1 RangeIndex
    traindf = df.iloc[train_idx].reset_index(drop=True)
    valdf = df.iloc[val_idx].reset_index(drop=True)

    model = SWEModel()

    #Callbacks
    early_stop_callback = EarlyStopping(monitor="val_RMSE", min_delta=args.min_delta, patience=args.patience, 
                                        verbose=False, mode="min")
    progressbar = TQDMProgressBar(refresh_rate = 10)
    checkpoint_callback = ModelCheckpoint(dirpath=weights_dir, 
                                          filename= f"{fold}best_weights", save_top_k=1, monitor="val_RMSE")
    lr_monitor = LearningRateMonitor(logging_interval='step')

    wandb_logger = WandbLogger(name='12ch_ENb4_CosLR-higherLR',project = "SWE_test", entity = "snowcastshowdown", job_type='train', log_model = 'all')

    wandb_logger.watch(model)

    trainer = pl.Trainer(max_epochs=args.max_epochs, 
                        gpus=1, 
                        logger=wandb_logger,
                        callbacks=[early_stop_callback, 
                                    progressbar, 
                                    checkpoint_callback,
                                    lr_monitor])

    SWE_Datamodule = SWEDataModule(traindf, valdf, args = args, loaderargs = loaderargs)

    trainer.fit(model, SWE_Datamodule)

    wandb.finish()

    #Drop every reference to this fold's model before emptying the CUDA cache: the trainer,
    #datamodule and logger still point at it, so `del model` alone frees no GPU memory.
    del model, trainer, SWE_Datamodule, wandb_logger
    gc.collect()
    torch.cuda.empty_cache()

[34m[1mwandb[0m: Currently logged in as: [33msnowcastshowdown[0m (use `wandb login --relogin` to force relogin)


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name       | Type              | Params
--------------------------------------------------
0  | model1     | EfficientNet      | 17.6 M
1  | fc2        | Linear            | 225 K 
2  | fc3        | Linear            | 37.0 K
3  | fc4        | Linear            | 97    
4  | dropout    | Dropout           | 0     
5  | relu       | ReLU              | 0     
6  | _criterion | BCEWithLogitsLoss | 0     
7  | cv1        | Conv2d            | 4.1 M 
8  | bn1        | BatchNorm2d       | 512   
9  | cv2        | Conv2d            | 147 K 
10 | bn2        | BatchNorm2d       | 128   
11 | mp1        | MaxPool2d         | 0     
12 | mp2        | MaxPool2d         | 0     
13 | trainr2    | R2Score           | 0     
14 | valr2      | R2Score           | 0  

Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 1212


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [24]:
#Manual recovery: close a W&B run left open if the training loop was interrupted before its own wandb.finish()
wandb.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

In [14]:
#Debugging aid: print a detailed CUDA memory report (e.g. after an out-of-memory error)
print(torch.cuda.memory_summary(device=None, abbreviated=False))

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 1            |        cudaMalloc retries: 1         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   13255 MB |   13255 MB |   36303 MB |   23047 MB |
|       from large pool |   13183 MB |   13183 MB |   36047 MB |   22863 MB |
|       from small pool |      72 MB |      72 MB |     255 MB |     183 MB |
|---------------------------------------------------------------------------|
| Active memory         |   13255 MB |   13255 MB |   36303 MB |   23047 MB |
|       from large pool |   13183 MB |   13183 MB |   36047 MB |   22863 MB |
|       from small pool |      72 MB |      72 MB |     255 MB |     183 MB |
|---------------------------------------------------------------

In [13]:
#Manual recovery: release cached, unreferenced CUDA memory held by PyTorch's allocator
torch.cuda.empty_cache()