# In[2]:
import progressbar
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch import nn
from torch import optim
from torchvision import transforms


class CNN_UNET_(pl.LightningModule):
    """Two-level U-Net that predicts 8 spatially normalised 2-D maps.

    The conditioning input ``z`` is concatenated (on dim 1) with a fixed
    initial guess ``y_pred``. The result goes through an encoder/decoder with
    skip connections, and the network returns 8 maps that are each
    non-negative and sum to 1 over their spatial positions.

    Shape requirements implied by the layers below:
      * ``z.shape[1] + y.shape[1] == 76`` (input channels of ``conv1``);
      * H and W divisible by 4. Two 2x max-pools are undone by two stride-2
        transposed convolutions, and the skip concatenations need matching
        sizes.

    ``training_step`` / ``validation_step`` use ``Bhatta_loss``, ``alpha1``
    and ``alpha2``, which are defined outside this class.
    """

    def __init__(self):
        super().__init__()

        # Encoder blocks (full, 1/2 and 1/4 resolution).
        # NOTE: creation order is kept as-is so state_dict keys, parameter
        # order (optimizer state) and random init are unchanged.
        self.conv1 = self._conv_block(76, 64)
        self.conv2 = self._conv_block(64, 128)
        self.conv3 = self._conv_block(128, 256)

        # Decoder blocks. conv_up1 ends with a projection to the 8 output maps.
        self.conv_up1 = nn.Sequential(
            *self._conv_block(128, 64),
            nn.Conv2d(64, 8, kernel_size=3, padding=1, bias=False),
        )
        self.conv_up2 = self._conv_block(256, 128)

        # 2x upsampling back to the resolution of the matching skip tensor.
        self.convTrans3 = nn.ConvTranspose2d(256, 128, kernel_size=2, padding=0, stride=2)
        self.convTrans2 = nn.ConvTranspose2d(128, 64, kernel_size=2, padding=0, stride=2)

        # Softmax over the flattened spatial axis (dim 2 after flattening H, W).
        self.softmax = nn.Softmax(dim=2)
        self.flatten = nn.Flatten(start_dim=2, end_dim=-1)
        self.maxpool2d = nn.MaxPool2d(2)
        self.avgpool2d = nn.AvgPool2d(2)  # not used in forward
        self.relu = nn.ReLU(True)

    @staticmethod
    def _conv_block(in_ch, out_ch):
        """Return Conv3x3-ReLU-BN-Conv3x3-BN ('same' padding, no bias)."""
        # NOTE(review): unlike a standard U-Net block there is no ReLU after
        # the final BatchNorm; confirm this is intended.
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.ReLU(True),
            nn.BatchNorm2d(out_ch),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
        )

    def forward(self, z, y_filter, y):
        """Predict 8 spatially normalised maps.

        Args:
            z: conditioning input, concatenated with the initial guess on dim 1.
            y_filter: not used by the forward pass; accepted so the call
                mirrors the batch layout ``(z, y_f, y)``.
            y: only its shape is used, to build the initial guess ``y_pred``.

        Returns:
            Tensor shaped like the decoder output ``(B, 8, H, W)``. Each
            ``[b, c]`` map is non-negative and sums to 1 over H*W.
        """
        # Initial guess: zeros except a 2x2 patch of 0.25 (total mass 1).
        # Presumably the centre of a 100x100 grid; confirm against the data.
        # Allocated with z's dtype/device so the concatenation works on GPU.
        y_pred = z.new_zeros(y.shape)
        y_pred[:, :, 49:51, 49:51] = 0.25

        # Encoder.
        e1 = self.conv1(torch.cat((z, y_pred), dim=1))
        e2 = self.conv2(self.maxpool2d(e1))
        e3 = self.conv3(self.maxpool2d(e2))

        # Decoder with skip connections to e2 and e1.
        d2 = self.conv_up2(torch.cat((e2, self.convTrans3(e3)), dim=1))
        d1 = self.conv_up1(torch.cat((e1, self.convTrans2(d2)), dim=1))

        # softmax(log(relu(x) + eps)) == (relu(x) + eps) / sum(relu(x) + eps),
        # i.e. each channel's map becomes a distribution over H*W.
        y_hat = self.flatten(d1)
        y_hat = self.relu(y_hat)
        y_hat = self.softmax(torch.log(y_hat + 1e-10))

        # Restore (B, 8, H, W) from the decoder output's own shape rather
        # than an external grid-size global.
        return y_hat.view(d1.shape)

    def configure_optimizers(self):
        """Adam with lr=1e-3, betas=(0.5, 0.999) and no weight decay."""
        lr = 0.001
        optimizer = optim.Adam(self.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0)
        return optimizer

    @staticmethod
    def _mean_channel_loss(y_hat, target):
        """Average ``Bhatta_loss`` of ``y_hat`` vs ``target`` over y_hat's channels."""
        n_out = y_hat.shape[1]
        total = 0
        for i in range(n_out):
            total = total + Bhatta_loss(y_hat[:, i, :, :], target[:, i, :, :])
        return total / n_out

    def training_step(self, batch, batch_idx):
        """Return ``alpha1 * loss_vs_y_f + alpha2 * loss_vs_y``, each averaged over channels."""
        z, y_f, y = batch
        y_hat = self(z, y_f, y)

        loss = alpha1 * self._mean_channel_loss(y_hat, y_f) + alpha2 * self._mean_channel_loss(y_hat, y)

        # Loss on the last output channel only, against the filtered target.
        loss_filter_200m = Bhatta_loss(y_hat[:, -1, :, :], y_f[:, -1, :, :])

        self.log("loss_train", loss, on_epoch=True, on_step=True)
        # Previously logged `loss` under this key by mistake.
        self.log("loss_filter_200m_train", loss_filter_200m, on_epoch=True, on_step=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log channel-averaged losses vs ``y_f`` and ``y``; return the one vs ``y``."""
        z, y_f, y = batch
        y_hat = self(z, y_f, y)

        loss_filter = self._mean_channel_loss(y_hat, y_f)
        loss_no_filter = self._mean_channel_loss(y_hat, y)

        # Loss on the last output channel only, against the filtered target.
        loss_filter_200m = Bhatta_loss(y_hat[:, -1, :, :], y_f[:, -1, :, :])

        self.log("loss_filter_validation", loss_filter, on_epoch=True, on_step=True)
        self.log("loss_no_filter_validation", loss_no_filter, on_epoch=True, on_step=True)
        self.log("loss_filter_200m_validation", loss_filter_200m, on_epoch=True, on_step=True)

        return loss_no_filter