In [1]:
!pip install timm 
!pip install wandb --quiet
!pip install pytorch-lightning

Looking in indexes: https://pypi.org/simple, https://pip.repos.neuron.amazonaws.com
Looking in indexes: https://pypi.org/simple, https://pip.repos.neuron.amazonaws.com




In [2]:
#All of our imports
import numpy as np
import pandas as pd

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.nn import functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR

import torchvision
from torchvision import transforms as T
from torchvision.io import read_image

from torchmetrics import R2Score

import timm

from tqdm import tqdm_notebook as tqdm

import pytorch_lightning as pl
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import progress
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.callbacks import LearningRateMonitor

import sklearn
from sklearn.model_selection import StratifiedKFold

import wandb

In [3]:
#Reading in the data
df = pd.read_csv(f'/home/ubuntu/SnowData/Final_CNN_Dataframe.csv')

#Designating which columns are our metadata
feature_cols = [col for col in df.columns 
                if col not in 
                ['cell_id', 'date', 'MOD10A1_filelocations', 'MYD10A1_filelocations', 
                 'copernicus_filelocations', 'SWE','sentinel1_filelocation','sentinel2a_filelocation',
                 'sentinel2b_filelocation','SWE_Scaled',
                 'mean_inversed_swe', 'mean_local_swe', 'median_local_swe', 'max_local_swe', 'min_local_swe',
                 'mean_local_elevation', 'median_local_elevation', 'max_local_elevation', 'min_local_elevation']]

#Min max scaling the meta data
scaler = sklearn.preprocessing.MinMaxScaler()
df[feature_cols] =  scaler.fit_transform(df[feature_cols])

#We will create a separate scaler for the targets so that we can transform them back and forth
target_scaler = sklearn.preprocessing.MinMaxScaler()
target_scaler.fit(np.array(df['SWE']).reshape(-1, 1))
df['SWE_Scaled'] = target_scaler.transform(np.array(df['SWE']).reshape(-1, 1))

tabluar_columns = len(feature_cols)

In [4]:
#Tmp drop two rows until better solution i s found
df = df.drop([56955,82314])
df.reset_index(inplace=True,drop=True)

In [5]:
#Join weather data
a = pd.read_csv('/home/ubuntu/SnowData/Unique_CellIDs_byDate__ASO_50M_SWE_USCALB__with_HRRR_TMP_surface_12h.csv',index_col=0)

b = pd.read_csv('/home/ubuntu/SnowData/Unique_CellIDs_byDate__ASO_50M_SWE_USCALB__with_HRRR_PRATE_surface_12h.csv',index_col=0)

weather = pd.merge(a, b,  how='left', on = ['index','cell_id','geometry','date','month_year'])

In [6]:
year = []
for i in df['date'].values:
    if pd.to_datetime(i).strftime('%Y') in year:
        continue
    else:
        year.append(pd.to_datetime(i).strftime('%Y'))

In [7]:
year

['2017', '2019', '2016', '2018']

In [8]:
df[pd.to_datetime(df['date']).dt.year == int('2019')]

Unnamed: 0,cell_id,date,SWE,mean_inversed_swe,mean_local_swe,median_local_swe,max_local_swe,min_local_swe,mean_local_elevation,median_local_elevation,...,MOD10A1_Albedo,MOD10A1_NDSI,MYD10A1_SnowCover,MYD10A1_Albedo,MYD10A1_NDSI,copernicus_filelocations,sentinel1_filelocation,sentinel2a_filelocation,sentinel2b_filelocation,SWE_Scaled
100,ASO_50M_SWE_USCAKC_7,2019-04-28,73.029772,27.858266,30.746571,30.051429,47.944286,19.234286,2380.488,2301.24,...,0.176204,0.830138,0.806131,0.339499,0.873714,/home/ubuntu/SnowData/CopernicusData/ASO_50M_S...,/home/ubuntu/SnowData/Sen1_Data_poly/ASO_50M_S...,/home/ubuntu/SnowData/Sen2_DataA_poly/ASO_50M_...,/home/ubuntu/SnowData/Sen2_DataB_poly/ASO_50M_...,0.510394
101,ASO_50M_SWE_USCAKC_8,2019-04-28,55.866164,27.913840,30.746571,30.051429,47.944286,19.234286,2380.488,2301.24,...,0.161360,0.793747,0.342564,0.128750,0.792446,/home/ubuntu/SnowData/CopernicusData/ASO_50M_S...,/home/ubuntu/SnowData/Sen1_Data_poly/ASO_50M_S...,/home/ubuntu/SnowData/Sen2_DataA_poly/ASO_50M_...,/home/ubuntu/SnowData/Sen2_DataB_poly/ASO_50M_...,0.390440
102,ASO_50M_SWE_USCAKC_9,2019-04-28,35.951281,27.954850,30.746571,30.051429,47.944286,19.234286,2380.488,2301.24,...,0.128699,0.736110,0.353559,0.159386,0.655651,/home/ubuntu/SnowData/CopernicusData/ASO_50M_S...,/home/ubuntu/SnowData/Sen1_Data_poly/ASO_50M_S...,/home/ubuntu/SnowData/Sen2_DataA_poly/ASO_50M_...,/home/ubuntu/SnowData/Sen2_DataB_poly/ASO_50M_...,0.251258
103,ASO_50M_SWE_USCAKC_10,2019-04-28,14.504536,27.985398,30.746571,30.051429,47.944286,19.234286,2380.488,2301.24,...,0.120098,0.543217,0.232479,0.114443,0.640034,/home/ubuntu/SnowData/CopernicusData/ASO_50M_S...,/home/ubuntu/SnowData/Sen1_Data_poly/ASO_50M_S...,/home/ubuntu/SnowData/Sen2_DataA_poly/ASO_50M_...,/home/ubuntu/SnowData/Sen2_DataB_poly/ASO_50M_...,0.101370
104,ASO_50M_SWE_USCAKC_11,2019-04-28,0.770748,28.009257,30.746571,30.051429,47.944286,19.234286,2380.488,2301.24,...,0.032358,0.408719,0.059303,0.033972,0.245388,/home/ubuntu/SnowData/CopernicusData/ASO_50M_S...,/home/ubuntu/SnowData/Sen1_Data_poly/ASO_50M_S...,/home/ubuntu/SnowData/Sen2_DataA_poly/ASO_50M_...,/home/ubuntu/SnowData/Sen2_DataB_poly/ASO_50M_...,0.005387
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95194,ASO_50M_SWE_USCATE_3504,2019-05-03,41.849002,24.757985,28.240571,27.872857,40.472857,16.915714,2755.392,2621.28,...,0.000000,0.854219,0.904700,0.590823,0.938346,/home/ubuntu/SnowData/CopernicusData/ASO_50M_S...,/home/ubuntu/SnowData/Sen1_Data_poly/ASO_50M_S...,/home/ubuntu/SnowData/Sen2_DataA_poly/ASO_50M_...,/home/ubuntu/SnowData/Sen2_DataB_poly/ASO_50M_...,0.292476
95195,ASO_50M_SWE_USCATE_3505,2019-05-03,24.609933,24.908190,28.240571,27.872857,40.472857,16.915714,2755.392,2621.28,...,0.000000,0.864160,0.965664,0.631072,0.978420,/home/ubuntu/SnowData/CopernicusData/ASO_50M_S...,/home/ubuntu/SnowData/Sen1_Data_poly/ASO_50M_S...,/home/ubuntu/SnowData/Sen2_DataA_poly/ASO_50M_...,/home/ubuntu/SnowData/Sen2_DataB_poly/ASO_50M_...,0.171995
95196,ASO_50M_SWE_USCATE_3506,2019-05-03,41.967672,25.059875,28.240571,27.872857,40.472857,16.915714,2755.392,2621.28,...,0.000000,0.845639,0.939891,0.616825,0.963127,/home/ubuntu/SnowData/CopernicusData/ASO_50M_S...,/home/ubuntu/SnowData/Sen1_Data_poly/ASO_50M_S...,/home/ubuntu/SnowData/Sen2_DataA_poly/ASO_50M_...,/home/ubuntu/SnowData/Sen2_DataB_poly/ASO_50M_...,0.293305
95197,ASO_50M_SWE_USCATE_3507,2019-05-03,37.977934,25.209928,28.240571,27.872857,40.472857,16.915714,2755.392,2621.28,...,0.000000,0.817995,0.893079,0.536370,0.929036,/home/ubuntu/SnowData/CopernicusData/ASO_50M_S...,/home/ubuntu/SnowData/Sen1_Data_poly/ASO_50M_S...,/home/ubuntu/SnowData/Sen2_DataA_poly/ASO_50M_...,/home/ubuntu/SnowData/Sen2_DataB_poly/ASO_50M_...,0.265422


In [23]:
class args:
    #Overall Args
    folder_name = "/home/ubuntu/SnowData"

    #Keep track of features used in wandb
    features = feature_cols

    #Setting the number of CPU workers we are using
    num_workers = 4

    #Setting the seed so we can replicate
    seed = 1212

    #Toggle for whether or not we want our model pretrained on imagenet
    pretrained = True

    #Next we pick the model name with the appropriate shape, img size and output
    model_name1 = 'mixnet_s'
    model_shape1 = 1536
    model_name2 = 'tf_efficientnet_b2_ns'
    model_shape2 = 1408 #768 for swin small 1536 for swin large 1792 for efficientnet b4 768 for cait-m-36
    imagesize = 224
    num_classes = 1
    img_channels = 3

    #LSTM variables
    lstm_hidden = 64
    lstm_layers = 1
    lstm_seqlen = 10

    #Training Args
    train_batch_size = 24
    val_batch_size = 24
    test_batch_size = 24

    #Max epochs and number of folds
    max_epochs = 100
    n_splits = 2

    #Optimizer and Scheduler args
    loss = 'nn.BCEWithLogitsLoss'
    lr = 3e-4
    warmup_epochs = 5
    weight_decay = 3e-6
    eta_min = 0.000001
    n_accumulate = 1
    T_0 = 25
    T_max = 2000

    #Callback args
    #Minimum number amount of improvement to not trigger patience
    min_delta = 0.0
    #Number of epochs in a row to wait for improvement
    patience = 30

#Dataloader Args
loaderargs = {'num_workers' : args.num_workers, 'pin_memory': False, 'drop_last': False}
device = torch.device("cuda:0")

seed_everything(args.seed)

Global seed set to 1212


1212

In [36]:
#Datasets are how pytorch knows how to read in the data
class SWEDataset(torch.utils.data.Dataset):
    def __init__(self, df,ts, test = False, seq_len = 10):
        self.df = df
        self.seq_len = seq_len
        #First we must specify the path to the images
        #self.MOD10A1_file_names = df['MOD10A1_filelocations'].values
        #self.MYD10A1_file_names = df['MYD10A1_filelocations'].values
        self.copernicus_file_names = df['copernicus_filelocations'].values
        self.sentinel1_file_names = df['sentinel1_filelocation'].values
        self.sentinel2a_file_names = df['sentinel2a_filelocation'].values
        self.sentinel2b_file_names = df['sentinel2b_filelocation'].values
        #Variables to query time series and output
        self.cell_id = df['cell_id'].values
        self.date = df['date'].values
        self.timeseries = ts
        #The only transform we want to do right now is the resizing
        self._transform = T.Resize(size= (args.imagesize, args.imagesize))
        #We specify the tabular feature columns
        self.meta = df[feature_cols].values
        #Now we specify the targets
        self.targets = df['SWE_Scaled'].values
        #Finally we specify if this is training or test
        self.test = test
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        #Get the image, scale it to between 0-1 and resize it
        copernicus_img_path = self.copernicus_file_names[index]
        copernicus_img = read_image(copernicus_img_path,mode = torchvision.io.image.ImageReadMode.RGB) / 255
        copernicus_img = self._transform(copernicus_img)
        
        sentinel1_img_path = self.sentinel1_file_names[index]
        sentinel1_img = read_image(sentinel1_img_path, mode = torchvision.io.image.ImageReadMode.RGB) / 255
        sentinel1_img = self._transform(sentinel1_img)
        
        sentinel2a_img_path = self.sentinel2a_file_names[index]
        sentinel2a_img = read_image(sentinel2a_img_path, mode = torchvision.io.image.ImageReadMode.RGB) / 255
        sentinel2a_img = self._transform(sentinel2a_img)
        
        sentinel2b_img_path = self.sentinel2b_file_names[index]
        sentinel2b_img = read_image(sentinel2b_img_path, mode = torchvision.io.image.ImageReadMode.RGB) / 255
        sentinel2b_img = self._transform(sentinel2b_img)

        #Pull from weather data and generate time-series
        date_range = pd.date_range(end=self.date[index], periods=self.seq_len)
        ts = []
        for date in date_range:
            query = self.timeseries.loc[(self.timeseries['cell_id']==self.cell_id[index]) & (self.timeseries['date']==self.date[index])]
            if not query.empty:
                ts.append([list(query['HRRR_TMP_surface_12h'])[0],list(query['HRRR_PRATE_surface_12h'])[0]])
            else:
                ts.append([np.nan,np.nan])
        ts = torch.tensor(ts)
        
        
        #Pull in the features for our batch
        meta = self.meta[index, :]
        
        #Specify the target based on whether this is training or test
        if self.test:
          target = 0
        else:
          target = self.targets[index]
            
        return copernicus_img, sentinel1_img, sentinel2a_img, sentinel2b_img, target, meta , ts

In [37]:
#Pytorch Lightning Requires that the dataset be formatted as a module
class SWEDataModule(pl.LightningDataModule):
    def __init__(self, traindf, valdf,ts,args, loaderargs):
        super().__init__()
        #Import our training and validation set, which we will define later
        self._train_df = traindf
        self._val_df = valdf
        self.ts = ts

        #Makesure we bring in our args so we can use them
        self.args = args
        self.loaderargs = loaderargs

    #Building the datasets
    def __create_dataset(self, train=True):
        if train == 'train':
          return SWEDataset(self._train_df,self.ts)
        else:
          return SWEDataset(self._val_df, self.ts)

    #Using the datasets to return a dataloader
    def train_dataloader(self):
        SWE_train = self.__create_dataset("train")
        return DataLoader(SWE_train, **self.loaderargs, batch_size=self.args.train_batch_size)

    def val_dataloader(self):
        SWE_val = self.__create_dataset("val")
        return DataLoader(SWE_val, **self.loaderargs, batch_size=self.args.val_batch_size)

In [38]:
def get_default_transforms():
    transform = {
        "train": T.Compose(
            [
                #T.RandomHorizontalFlip(),
                #T.RandomVerticalFlip(),
                #T.RandomAffine(15, translate=(0.1, 0.1), scale=(0.9, 1.1)),
                #T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
                T.ConvertImageDtype(torch.float),
                T.Normalize(mean = (0.485, 0.456, 0.406), 
                            std = (0.229, 0.224, 0.225))
                
            ]
        ),
        "val": T.Compose(
            [
                T.ConvertImageDtype(torch.float),
                T.Normalize(mean = (0.485, 0.456, 0.406), 
                            std = (0.229, 0.224, 0.225))
            ]
        ),
    }
    return transform
  

def mixup(x1: torch.Tensor, x2: torch.Tensor, x3: torch.Tensor,x4: torch.Tensor,
          y: torch.Tensor, 
          z = torch.Tensor, alpha: float = 1.0):
    assert alpha > 0, "alpha should be larger than 0"
    assert x1.size(0) > 1, "Mixup cannot be applied to a single instance."

    lam = np.random.beta(alpha, alpha)
    rand_index = torch.randperm(x1.size()[0])
    mixed_x1 = lam * x1 + (1 - lam) * x1[rand_index, :]
    mixed_x2 = lam * x2 + (1 - lam) * x2[rand_index, :]
    mixed_x3 = lam * x3 + (1 - lam) * x3[rand_index, :]
    mixed_x4 = lam * x4 + (1 - lam) * x4[rand_index, :]
    mixed_meta = lam * z + (1 - lam) * z[rand_index, :]
    target_a, target_b = y, y[rand_index]
    return mixed_x1,mixed_x2,mixed_x3, mixed_x4,mixed_meta, target_a, target_b,  lam

In [71]:
#Model
class CNNLSTM(LightningModule):
    def __init__(self):
        super().__init__()
        self.args = args
        self.scaler = target_scaler
        self.tabular_columns = tabluar_columns
        self._criterion = eval(self.args.loss)()
        self.transform = get_default_transforms()
        
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=.3)
        
        self.hidden_size = 2
        
        
        #Image Models
        self.model1 = timm.create_model(args.model_name1, 
                                       pretrained=args.pretrained, 
                                       num_classes=0,
                                       in_chans = 3,
                                       #global_pool=''
                                       )
        self.model2 = timm.create_model(args.model_name2, 
                                       pretrained=args.pretrained, 
                                       num_classes=0,
                                       in_chans = 3,
                                       #global_pool=''
                                       )
        self.model3 = timm.create_model(args.model_name2, 
                                       pretrained=args.pretrained, 
                                       num_classes=0,
                                       in_chans = 3,
                                       #global_pool=''
                                       )
        self.model4 = timm.create_model(args.model_name2, 
                                       pretrained=args.pretrained, 
                                       num_classes=0,
                                       in_chans = 3,
                                       #global_pool=''
                                       )
        #LSTM
        self.lstm = nn.LSTM(input_size = 2,
                            hidden_size = 64,
                            num_layers = args.lstm_layers,
                            batch_first=True,dropout=.1)
        #Possible multiple LSTM layers?
        #self.lstm = nn.LSTM(input_size = args.lstm_hidden,
        #                    hidden_size = self.hidden_size,
        ##                    num_layers = args.lstm_layers,
        #                    batch_first=True,dropout=.1)
        
        #Linear regression layer
        self.linear1 = nn.Linear(6406,1024)
        self.linear2 = nn.Linear(1024,256)
        self.linear3 = nn.Linear(1024,args.num_classes)
    
        
    def forward(self,features1,features2,features3,features4,meta,ts):
        
        
        
        #Image Convolution
        #Image Models
        features1 = self.model1(features1)                 
        features1 = self.relu(features1)
        features1 = self.dropout(features1)
        
        features2 = self.model2(features2)                 
        features2 = self.relu(features2)
        features2 = self.dropout(features2)
        
        features3 = self.model3(features3)                 
        features3 = self.relu(features3)
        features3 = self.dropout(features3)
        
        features4 = self.model4(features4)                 
        features4 = self.relu(features4)
        features4 = self.dropout(features4)
        

        #LSTM
        batch_size, seq_len, feature_len = ts.size()
        # Initialize hidden state with zeros
        
        h_0 = torch.zeros(1, batch_size, 64,requires_grad=True).cuda()
        c_0 = torch.zeros(1, batch_size, 64,requires_grad=True).cuda()
        
        f_ts, (final_hidden,final_cell) = self.lstm(ts, (h_0,c_0))
        f_ts = f_ts.contiguous().view(batch_size,-1)
        
        #*************************************************************
        #Concatenate meta and image features
        features = torch.cat([features1,features2,features3,features4,f_ts,meta],dim=1)
        #*************************************************************
        
        #Linear
        features = self.linear1(features)
        features = self.relu(features)
        features = self.dropout(features)
        
        features = self.linear2(features)
        features = self.relu(features)
        features = self.dropout(features)
        
        output = self.linear2(features)
        return output
    
###I DIDN"T MIX UP TS data
    def __share_step(self, batch, mode):
        copernicus_img, sentinel1_img, sentinel2a_img, sentinel2b_img, labels, meta,ts = batch
        labels = labels.float()
        meta = meta.float()
        ts = ts.float()
        copernicus_img = self.transform[mode](copernicus_img)
        sentinel1_img = self.transform[mode](sentinel1_img)
        sentinel2a_img = self.transform[mode](sentinel2a_img)
        sentinel2b_img = self.transform[mode](sentinel2b_img)

        rand_index = torch.rand(1)[0]
        
        #This is a mixup function
        if rand_index < 0.5 and mode == 'train':
            copernicus_mixed,sentinel1_mixed,sentinel2a_mixed,sentinel2b_mixed, mixed_meta, target_a, target_b, lam = mixup(
                                                          copernicus_img,sentinel1_img,sentinel2a_img,sentinel2b_img,
                                                          labels, meta, alpha=0.5)
            logits = self.forward(copernicus_mixed,sentinel1_mixed,sentinel2a_mixed,sentinel2b_mixed, mixed_meta,ts).squeeze(1)
            loss = self._criterion(logits, target_a) * lam + \
                (1 - lam) * self._criterion(logits, target_b)

        else:  
          logits = self.forward(copernicus_img,sentinel1_img,sentinel2a_img,sentinel2b_img, meta,ts).squeeze(1)
          loss = self._criterion(logits, labels)

        pred = torch.from_numpy(self.scaler \
            .inverse_transform(np.array(logits.sigmoid().detach().cpu()) \
            .reshape(-1, 1)))
        labels = torch.from_numpy(self.scaler \
            .inverse_transform(np.array(labels.detach().cpu()) \
            .reshape(-1, 1)))
        
        '''
        #This is random noise
        elif rand_index > 0.8 and mode == 'train':
            images = images + (torch.randn(images.size(0),3,args.imagesize,args.imagesize, 
                                           dtype = torch.float, device = device)*10)/100
            logits = self.forward(images, meta).squeeze(1)
            loss = self._criterion(logits, labels)
        '''

        return loss, pred, labels

    def training_step(self, batch, batch_idx):
        loss, pred, labels = self.__share_step(batch, 'train')
        self.trainr2(pred.cuda(),labels.cuda())
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        return {'loss': loss, 'pred': pred, 'labels': labels}
    
    def validation_step(self, batch, batch_idx):
        loss, pred, labels = self.__share_step(batch, 'val')
        self.valr2(pred.cuda(),labels.cuda())
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        return {'pred': pred, 'labels': labels}

    def training_epoch_end(self, outputs):
        self.log('train_r2_epoch',self.trainr2)
        self.__share_epoch_end(outputs, 'train')

    def validation_epoch_end(self, outputs):
        self.log('val_r2_epoch',self.valr2)
        self.__share_epoch_end(outputs, 'val')

    def __share_epoch_end(self, outputs, mode):
        preds = []
        labels = []
        for out in outputs:
            pred, label = out['pred'], out['labels']
            preds.append(pred)
            labels.append(label)
        preds = torch.cat(preds)
        labels = torch.cat(labels)
        metrics = torch.sqrt(((labels - preds) ** 2).mean())
        self.log(f'{mode}_RMSE', metrics)    


    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=args.lr, weight_decay = args.weight_decay)
        
        return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": CosineAnnealingLR(optimizer, T_max = args.T_max, eta_min= args.eta_min),
            "interval": "step",
            "monitor": "train_loss",
            "frequency": 1}
            }

In [72]:
#Not doing kfold, instead separating by year
traindf = df[(pd.to_datetime(df['date']).dt.year == int('2019'))|
            (pd.to_datetime(df['date']).dt.year == int('2018'))].copy().reset_index(drop=True)
valdf = df[(pd.to_datetime(df['date']).dt.year == int('2016'))|
            (pd.to_datetime(df['date']).dt.year == int('2017'))].copy().reset_index(drop=True)

model = CNNLSTM()

modelname = 'sepimage-cnn-lstm'

#Callbacks
early_stop_callback = EarlyStopping(monitor="val_RMSE", min_delta=args.min_delta, patience=args.patience, 
                                    verbose=False, mode="min")
progressbar = TQDMProgressBar(refresh_rate = 10)
checkpoint_callback = ModelCheckpoint(dirpath='/home/ubuntu/snowcap/weights', 
                                      filename= f"{modelname}_best_weights", save_top_k=1, monitor="val_RMSE")
lr_monitor = LearningRateMonitor(logging_interval='step')

#Initialize wandb()
#wandb.init(name=modelname,project = "ASO_Modeling", entity = "snowcastshowdown", job_type='train')

#Log model parameters into wandb (args variable dictionary)
args_dict = dict(args.__dict__)
#pop out non-json-able variables
for key in ['__module__','__dict__','__weakref__','__doc__']:
    args_dict.pop(key,None)
#wandb.config.update(args_dict)


#wandb_logger = WandbLogger(log_model = 'all')

#wandb_logger.watch(model)

trainer = pl.Trainer(max_epochs=args.max_epochs, 
                    gpus=1, 
#                    logger=wandb_logger,
                    callbacks=[early_stop_callback, 
                                progressbar, 
                                checkpoint_callback,
                                lr_monitor])

SWE_Datamodule = SWEDataModule(traindf, valdf, weather, args = args, loaderargs = loaderargs)

trainer.fit(model, SWE_Datamodule)

#wandb.finish()

del model
torch.cuda.empty_cache()

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type              | Params
-------------------------------------------------
0 | _criterion | BCEWithLogitsLoss | 0     
1 | relu       | ReLU              | 0     
2 | dropout    | Dropout           | 0     
3 | model1     | EfficientNet      | 2.6 M 
4 | model2     | EfficientNet      | 7.7 M 
5 | model3     | EfficientNet      | 7.7 M 
6 | model4     | EfficientNet      | 7.7 M 
7 | lstm       | LSTM              | 17.4 K
8 | linear1    | Linear            | 6.0 M 
9 | linear2    | Linear            | 1.0 K 
-------------------------------------------------
31.7 M    Trainable params
0         Non-trainable params
31.7 M    Total params
126.760   Total estimated model params size (MB)
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


Validation sanity check: 0it [00:00, ?it/s]

RuntimeError: mat1 and mat2 shapes cannot be multiplied (24x6406 and 5830x1024)