In [None]:
!pip install timm 
!pip install wandb --quiet
!pip install pytorch-lightning

In [None]:
#All of our imports
import numpy as np
import pandas as pd

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.nn import functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR

import torchvision
from torchvision import transforms as T
from torchvision.io import read_image

from torchmetrics import R2Score

import timm

from tqdm import tqdm_notebook as tqdm

import pytorch_lightning as pl
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import progress
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.callbacks import LearningRateMonitor

import sklearn
from sklearn.model_selection import StratifiedKFold

import wandb

In [None]:
#Reading in the data
# TODO: point DATA_PATH at the training CSV (its first column is used as the index).
DATA_PATH = ''
if not DATA_PATH:
    raise ValueError("DATA_PATH is empty: set it to the training CSV before running this cell.")
df = pd.read_csv(DATA_PATH, index_col=0)

#Designating which columns are our metadata: every column except identifiers,
#image file locations and the target columns.
non_feature_cols = {'cell_id', 'date', 'MOD10A1_filelocations', 'MYD10A1_filelocations',
                    'copernicus_filelocations', 'SWE', 'sentinel1a_filelocation', 'SWE_Scaled'}
feature_cols = [col for col in df.columns if col not in non_feature_cols]

# NOTE(review): both scalers are fit on the full dataset, before any cross-validation split,
# so validation folds leak into the scaling statistics. Consider fitting per fold.

#Min max scaling the meta data
scaler = sklearn.preprocessing.MinMaxScaler()
df[feature_cols] = scaler.fit_transform(df[feature_cols])

#A separate scaler for the target so predictions can be mapped back to SWE units
#with target_scaler.inverse_transform. It is fit on a plain (n, 1) array, and the
#(n, 1) result is flattened before being stored as a column.
target_scaler = sklearn.preprocessing.MinMaxScaler()
df['SWE_Scaled'] = target_scaler.fit_transform(df[['SWE']].to_numpy()).ravel()

#Number of tabular features. The misspelled name is kept because later cells read it;
#tabular_columns is the correctly spelled alias.
tabluar_columns = len(feature_cols)
tabular_columns = tabluar_columns

In [None]:
class args:
    """Notebook-wide configuration, read by later cells as class attributes (``args.<name>``)."""

    # ---- general ----
    folder_name = "/"
    num_workers = 4              # CPU worker processes used for data loading
    features = feature_cols      # tabular feature columns chosen in the data cell
    seed = 1212                  # global seed so runs can be replicated

    # ---- image backbone ----
    pretrained = True            # start from ImageNet-pretrained weights
    # model_shape per backbone: 768 swin small, 1536 swin large,
    # 1792 efficientnet b4, 768 cait-m-36
    model_name = 'tf_efficientnet_b4_ns'
    model_shape = 1792
    imagesize = 224
    num_classes = 1
    channels = 12                # channels of the stacked image fed to the CNN

    # ---- LSTM ----
    lstm_hidden = 64
    lstm_layers = 1
    lstm_seqlen = 7

    # ---- batch sizes ----
    train_batch_size = 32
    val_batch_size = 32
    test_batch_size = 32

    # ---- training length and cross-validation ----
    max_epochs = 80
    n_splits = 2

    # ---- loss, optimizer and scheduler ----
    loss = 'nn.BCEWithLogitsLoss'
    lr = 3e-4
    warmup_epochs = 5
    weight_decay = 3e-6
    eta_min = 1e-6
    n_accumulate = 1
    T_0 = 25
    T_max = 1000

    # ---- callbacks ----
    min_delta = 0.0              # smallest improvement that still resets patience
    patience = 25                # consecutive epochs without improvement before stopping

#Dataloader Args
loaderargs = {'num_workers' : args.num_workers, 'pin_memory': False, 'drop_last': False}

# Use the first GPU when one is present; otherwise fall back to CPU instead of
# failing later when tensors or modules are moved to `device`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed through the top-level Lightning export, which does not depend on the
# pytorch_lightning.utilities.seed module path.
pl.seed_everything(args.seed)


In [None]:
# TODO: data read-in (building train/validation/test frames) goes here


In [None]:
#Dataset
class SWEDataset(torch.utils.data.Dataset):
    """Per-row dataset pairing stacked satellite images with tabular features and the SWE target.

    Each item is ``(image, target, meta)``:
      * ``image``  -- the MOD10A1, MYD10A1, Copernicus and Sentinel-1 images, each read as RGB,
                      scaled to [0, 1], passed through the transform and stacked along the
                      channel axis (4 x 3 = 12 channels, matching ``args.channels``).
      * ``target`` -- float32 scalar tensor with the scaled SWE value (``0.0`` when ``test``).
      * ``meta``   -- float32 tensor holding this row's tabular features.

    NOTE(review): the model's LSTM expects a *sequence* of feature rows; the sliding-window
    loader that builds that sequence is still a TODO, so ``meta`` is a single row here.

    Args:
        df: frame with the four image-location columns, the feature columns and,
            unless ``test`` is True, ``SWE_Scaled``.
        test: when True no target is read and ``0.0`` is returned instead.
        features: feature columns to use; defaults to ``args.features``.
        transform: callable applied to each image tensor; defaults to resizing to
            ``args.imagesize`` x ``args.imagesize``.
    """

    def __init__(self, df, test = False, features = None, transform = None):
        # Reset the index so positional lookups line up with len(self.df).
        self.df = df.reset_index(drop=True)
        self.test = test

        self.MOD10A1_file_names = self.df['MOD10A1_filelocations'].to_numpy()
        self.MYD10A1_file_names = self.df['MYD10A1_filelocations'].to_numpy()
        self.copernicus_file_names = self.df['copernicus_filelocations'].to_numpy()
        self.sentinel1_file_names = self.df['sentinel1a_filelocation'].to_numpy()

        feature_list = list(args.features if features is None else features)
        self.meta = self.df[feature_list].to_numpy(dtype=np.float32)
        self.targets = None if test else self.df['SWE_Scaled'].to_numpy(dtype=np.float32)

        self._transform = (transform if transform is not None
                           else T.Resize((args.imagesize, args.imagesize)))

    def __len__(self):
        return len(self.df)

    def _load_image(self, path):
        """Read one image file as RGB, scale it to [0, 1] and apply the dataset transform."""
        img = read_image(str(path), mode = torchvision.io.image.ImageReadMode.RGB) / 255
        return self._transform(img)

    def __getitem__(self, index):
        images = [self._load_image(names[index]) for names in (self.MOD10A1_file_names,
                                                              self.MYD10A1_file_names,
                                                              self.copernicus_file_names,
                                                              self.sentinel1_file_names)]
        # read_image returns (C, H, W), so dim=0 is the channel axis: four RGB images
        # become one 12-channel image (concatenating on dim=1 would stack them vertically).
        image = torch.cat(images, dim=0)

        #Pull in the features for this sample
        meta = torch.tensor(self.meta[index])

        #Test rows carry no label, so a placeholder target keeps the tuple shape uniform
        target = torch.tensor(0.0 if self.test else self.targets[index], dtype=torch.float32)

        return image, target, meta

In [None]:
# TODO: DataLoader goes here (sliding window to build the LSTM input sequence).
# Each batch should yield the multi-channel image (all satellite images stacked along the
# channel axis) plus a sequence of tabular feature rows for the LSTM.


In [None]:
#Model
class CNNLSTM(pl.LightningModule):
    """Regressor fusing CNN features of the stacked satellite image with an LSTM over tabular features.

    The image branch is a timm backbone (``args.model_name``, ``args.channels`` input channels)
    pooled to a ``self.cnn.num_features`` vector. The tabular branch is an LSTM whose outputs for
    all ``args.lstm_seqlen`` steps are flattened. Both branches are concatenated and passed
    through two linear layers that produce ``args.num_classes`` outputs.
    """

    def __init__(self):
        super().__init__()
        self.args = args
        self.scaler = target_scaler
        self.tabular_columns = tabluar_columns
        # args.loss is a config string such as 'nn.BCEWithLogitsLoss', resolved with eval in
        # this notebook's namespace; never point it at untrusted input.
        self._criterion = eval(self.args.loss)()
        # NOTE(review): this used to call get_default_transforms(), which is not defined anywhere
        # in this notebook (NameError at construction). Resizing is done by the dataset; plug
        # augmentation transforms in here if they are needed.
        self.transform = None

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=.3)

        #Image model: num_classes=0 drops the classifier head and global_pool='avg' yields one
        #(batch, num_features) vector per image, so the fused width is known up front.
        #(With global_pool='' the flattened feature map did not match linear1's input size.)
        self.cnn = timm.create_model(args.model_name,
                                     pretrained=args.pretrained,
                                     num_classes=0,
                                     in_chans=args.channels,
                                     global_pool='avg')

        #LSTM over the tabular feature sequence. Extra stacked layers come from num_layers
        #(a second nn.LSTM assigned to the same attribute silently replaced the first).
        #nn.LSTM only applies dropout between layers, so it is disabled for a single layer.
        self.lstm = nn.LSTM(input_size=len(args.features),
                            hidden_size=args.lstm_hidden,
                            num_layers=args.lstm_layers,
                            batch_first=True,
                            dropout=.1 if args.lstm_layers > 1 else 0.0)

        #Linear regression head on the concatenated image + flattened LSTM features
        fused_size = self.cnn.num_features + args.lstm_hidden * args.lstm_seqlen
        self.linear1 = nn.Linear(fused_size, args.lstm_hidden*2)
        self.linear2 = nn.Linear(args.lstm_hidden*2, args.num_classes)

        # Last zero state built by init_hidden.
        self.hidden = None

    def init_hidden(self, batch_size, device=None):
        """Build zeroed LSTM states ``(h_0, c_0)``, each (lstm_layers, batch_size, lstm_hidden).

        The pair is stored on ``self.hidden`` and also returned. ``device`` defaults to the
        module's current device.
        """
        device = self.device if device is None else device
        shape = (self.args.lstm_layers, batch_size, self.args.lstm_hidden)
        hidden_state = torch.zeros(shape, device=device)
        cell_state = torch.zeros(shape, device=device)
        self.hidden = (hidden_state, cell_state)
        return self.hidden

    def forward(self, image, meta):
        """Predict from an image batch and a tabular sequence batch.

        Args:
            image: stacked image batch for the CNN, presumably (batch, args.channels, H, W).
            meta: tabular sequence batch of shape (batch, args.lstm_seqlen, len(args.features)).

        Returns:
            (batch, args.num_classes) raw, pre-activation outputs.

        Raises:
            ValueError: if ``meta`` is not 3-D with ``args.lstm_seqlen`` steps, which linear1's
                input width depends on.
        """
        if meta.dim() != 3 or meta.size(1) != self.args.lstm_seqlen:
            raise ValueError(
                f"meta must have shape (batch, {self.args.lstm_seqlen}, n_features); "
                f"got {tuple(meta.shape)}")
        batch_size = meta.size(0)

        #Image branch: pooled CNN features, flattened to (batch, num_features)
        f_image = self.relu(self.cnn(image))
        f_image = f_image.reshape(f_image.size(0), -1)

        #Tabular branch: a fresh zero state every call, so no state (or autograd graph)
        #carries over from the previous batch.
        lstm_out, _ = self.lstm(meta, self.init_hidden(batch_size, meta.device))
        f_meta = lstm_out.reshape(batch_size, -1)

        #Concatenate meta and image features
        features = torch.cat([f_image, f_meta], dim=1)

        #Linear head
        features = self.dropout(self.relu(self.linear1(features)))
        return self.linear2(features)