In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils import data
from torchvision.models import vgg19
import torch.nn.functional as F
import itertools
import random
import numpy as np
import os
import time

!pip install kornia
from torch.autograd import Variable
import torch.nn.functional as F
from google.colab import drive, runtime
import seaborn as sns
import matplotlib.pyplot as plt
import sys

from loss import DivergenceLoss, DivergenceLoss2

# Mount Google Drive: the data prefix below (and the model save path at the end of the
# notebook) live under /content/gdrive, and no visible cell mounts it, so a fresh
# kernel would fail to find any data. Re-running prints "already mounted" harmlessly.
drive.mount('/content/gdrive')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Path PREFIX (not a directory): sample ID k is read from f"{train_direc}{k}.pt".
train_direc = '/content/gdrive/MyDrive/Research/ResearchData/TFNetData/data_point_'

In [None]:
class Dataset(data.Dataset):
    """Map-style dataset over per-sample ``<direc><ID>.pt`` tensor files.

    For each sample the first ``mid`` frames form the input and frame
    ``mid + time_window`` is the target.

    Args:
        indices: sample IDs; item ``i`` is read from ``direc + str(indices[i]) + ".pt"``.
        mid: number of leading frames used as the input.
        direc: path prefix (not a directory) prepended to each ID.
        time_window: offset after ``mid`` of the target frame.
    """

    def __init__(self, indices, mid, direc, time_window):
        self.mid = mid
        self.direc = direc
        self.list_IDs = indices
        self.time_window = time_window

    def __len__(self):
        return len(self.list_IDs)

    def __getitem__(self, index):
        """Return ``(x, y)`` as float32 tensors.

        x: frames ``[:mid]`` reshaped to (-1, H, W), i.e. frame and channel dims merged
           (assumes the stored tensor is (T, C, H, W) — TODO confirm layout).
        y: frame ``mid + time_window`` reshaped to (-1, H, W).
        """
        ID = self.list_IDs[index]
        # Read the file ONCE per sample (it used to be loaded from disk twice).
        # map_location="cpu" keeps dataset tensors on CPU regardless of where they were saved.
        sequence = torch.load(self.direc + str(ID) + ".pt", map_location="cpu")
        target = sequence[self.mid + self.time_window]
        height, width = target.shape[-2], target.shape[-1]

        x = sequence[:self.mid].reshape(-1, height, width)
        y = target.reshape(-1, height, width)
        return x.float(), y.float()

In [None]:
def conv_block(in_channels, out_channels):
  """Two 3x3 convolutions with ReLU; the first one halves H and W (stride 2).

  Args:
    in_channels: channels of the incoming feature map.
    out_channels: channels produced by both convolutions.

  Returns:
    nn.Sequential([Conv2d(stride=2), ReLU, Conv2d(stride=1), ReLU]).
  """
  downsample = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
  refine = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
  return nn.Sequential(downsample, nn.ReLU(), refine, nn.ReLU())

class FeatureExtractor(nn.Module):
    """Conv encoder followed by a 3-layer MLP that maps an image to a 2000-d vector.

    No pretrained weights are loaded in this class. ``fcc1`` expects 4096 flattened
    features, which equals 1024 channels x 2 x 2 — what five stride-2 ``conv_block``s
    produce from a 64x64 input (other sizes will not match ``fcc1``).

    Args:
        in_channels: channels of the input image.
    """

    def __init__(self, in_channels):
        super(FeatureExtractor, self).__init__()
        self.conv1 = conv_block(in_channels, 64)
        self.conv2 = conv_block(64, 128)
        self.conv3 = conv_block(128, 256)
        self.conv4 = conv_block(256, 512)
        self.conv5 = conv_block(512, 1024)
        self.fcc1 = nn.Linear(4096, 4096)
        self.fcc2 = nn.Linear(4096, 4096)
        self.fcc3 = nn.Linear(4096, 2000)
        self.relu = nn.ReLU()
        # stride 1 + padding 1: a 3x3 max filter that keeps the spatial size unchanged.
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)

    def forward(self, img):
        """Return a (batch, 2000) feature tensor for ``img``."""
        feats = img
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            feats = self.maxpool(stage(feats))

        flat = feats.reshape(feats.shape[0], -1)
        hidden = self.relu(self.fcc1(flat))
        hidden = self.relu(self.fcc2(hidden))
        return self.fcc3(hidden)


def conv(in_planes, out_planes, kernel_size=3, stride=1):
    """Conv2d (no bias) -> BatchNorm2d -> LeakyReLU(0.1, in place).

    Padding is (kernel_size - 1) // 2, so odd kernels keep H and W at stride 1 and
    roughly divide them by ``stride`` otherwise.
    """
    pad = (kernel_size - 1) // 2
    layers = [
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
                  stride=stride, padding=pad, bias=False),
        nn.BatchNorm2d(out_planes),
        nn.LeakyReLU(0.1, inplace=True),
    ]
    return nn.Sequential(*layers)

def deconv(in_planes, out_planes):
    """2x upsampling block: ConvTranspose2d(k=4, s=2, p=1, with bias) -> LeakyReLU(0.1, in place).

    With these settings an (H, W) input becomes (2H, 2W).
    """
    upsample = nn.ConvTranspose2d(
        in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True
    )
    return nn.Sequential(upsample, nn.LeakyReLU(0.1, inplace=True))

def predict_flow(in_planes, out_planes):
    """Size-preserving 3x3 bias-free Conv2d head mapping in_planes -> out_planes channels."""
    head = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False)
    return head


class LES(nn.Module):
    """Split a stack of 2-component fields into (u_tilde, u_mean, u_prime).

    A learnable 3x3 convolution smooths every channel independently; ``u_prime`` is
    the residual the smoothing removed, ``u_mean`` is the smoothed field averaged over
    all channel pairs, and ``u_tilde`` is the smoothed field minus that mean.

    Args:
        input_channels: channel count of the input stack (stored only).
        output_channels: unused; kept for interface compatibility.
        kernel_size: unused — the spatial filter is always 3x3; kept for compatibility.
        time_range: input channels of ``temporal_filter``.
    """

    def __init__(self, input_channels, output_channels, kernel_size, time_range):
        super(LES, self).__init__()
        # One shared single-channel 3x3 kernel, applied to every channel separately.
        self.spatial_filter = nn.Conv2d(1, 1, kernel_size = 3, padding = 1, bias = False)
        # NOTE(review): never used in forward(); kept so the parameter list and
        # state_dict layout stay unchanged for existing checkpoints/optimizers.
        self.temporal_filter = nn.Conv2d(time_range, 1, kernel_size = 1, padding = 0, bias = False)
        self.input_channels = input_channels
        self.time_range = time_range

    def forward(self, xx):
        """Decompose ``xx``.

        Args:
            xx: tensor of shape (B, C, H, W) with C even. Channels are grouped as
                C // 2 consecutive (component 0, component 1) pairs.

        Returns:
            (u_tilde, u_mean, u_prime), each of shape (B, C, H, W).
        """
        # Spatial size is taken from the input (it was hard-coded to 64x64 before).
        batch, channels, height, width = xx.shape

        filtered = self.spatial_filter(xx.reshape(batch * channels, 1, height, width))
        filtered = filtered.reshape(batch, channels, height, width)

        u_prime = xx - filtered

        # Average each component over the C // 2 pairs, then tile back to C channels.
        pairs = filtered.reshape(batch, channels // 2, 2, height, width)
        u_mean = torch.mean(pairs, 1).unsqueeze(1).repeat(1, channels // 2, 1, 1, 1)
        u_mean = u_mean.reshape(batch, channels, height, width)

        u_tilde = filtered - u_mean
        return u_tilde, u_mean, u_prime

class U_net(nn.Module):
    """U-Net-style generator applied to a three-way decomposition of its input.

    ``filter`` splits the input into (u_tilde, u_mean, u_prime). Each component goes
    through the SAME encoder (conv1..conv5 weights are shared); at every level the
    three feature maps are summed and used as skip connections for the decoder. The
    raw input is concatenated back in before a 3x3 head produces the prediction.

    Args:
        filter: module whose forward returns (u_tilde, u_mean, u_prime), each with
            ``input_channels`` channels.
        input_channels: channels of the input (and of every component).
        output_channels: channels of the prediction.
    """

    def __init__(self, filter, input_channels=4, output_channels=2):
        super(U_net, self).__init__()
        self.filter = filter
        self.input_channels = input_channels

        # Shared encoder: five stride-2 convolutions, with extra stride-1 convs at levels 3 and 4.
        self.conv1 = conv(input_channels, 64, kernel_size=3, stride=2)
        self.conv2 = conv(64, 128, kernel_size=3, stride=2)
        self.conv3 = conv(128, 256, kernel_size=3, stride=2)
        self.conv3_1 = conv(256, 256, kernel_size=3)
        self.conv4 = conv(256, 512, kernel_size=3, stride=2)
        self.conv4_1 = conv(512, 512, kernel_size=3)
        self.conv5 = conv(512, 1024, kernel_size=3, stride=2)

        # Decoder input widths = skip channels + upsampled channels:
        # 768 = 512 + 256, 384 = 256 + 128, 192 = 128 + 64, 128 = 64 + 32 + 32.
        self.deconv4 = deconv(1024, 256)
        self.deconv3 = deconv(768, 128)
        self.deconv2 = deconv(384, 64)
        self.deconv1 = deconv(192, 32)
        self.deconv0 = deconv(128, 16)

        self.predict_flow0 = predict_flow(16 + input_channels, output_channels)

    def _encode(self, component):
        """Run one component through the shared encoder; return the 5 level outputs."""
        level1 = self.conv1(component)
        level2 = self.conv2(level1)
        level3 = self.conv3_1(self.conv3(level2))
        level4 = self.conv4_1(self.conv4(level3))
        level5 = self.conv5(level4)
        return level1, level2, level3, level4, level5

    def forward(self, x):
        u_tilde, u_mean, u_prime = self.filter(x)

        # Keep the tilde -> mean -> prime call order: each pass updates the shared
        # BatchNorm running statistics, so reordering would change them.
        enc_tilde = self._encode(u_tilde)
        enc_mean = self._encode(u_mean)
        enc_prime = self._encode(u_prime)
        skip1, skip2, skip3, skip4, bottleneck = [
            prime + mean + tilde
            for tilde, mean, prime in zip(enc_tilde, enc_mean, enc_prime)
        ]

        up4 = self.deconv4(bottleneck)
        up3 = self.deconv3(torch.cat((skip4, up4), 1))
        up2 = self.deconv2(torch.cat((skip3, up3), 1))
        up1 = self.deconv1(torch.cat((skip2, up2), 1))
        # NOTE(review): up1 is concatenated twice (64 + 32 + 32 = 128 channels, matching
        # deconv0's input width). Looks accidental, but deconv0 is sized for it — confirm.
        up0 = self.deconv0(torch.cat((skip1, up1, up1), 1))
        return self.predict_flow0(torch.cat((x, up0), 1))

class Generator(nn.Module):
    """Thin wrapper exposing a ``U_net`` (default 2 output channels) as ``self.model``.

    Args:
        input_channels: channels of the generator input.
        filter: decomposition module handed to ``U_net``.
    """

    def __init__(self, input_channels, filter):
        super(Generator, self).__init__()
        self.model = U_net(filter=filter, input_channels=input_channels)

    def forward(self, xx):
        return self.model(xx)

class Discriminator(nn.Module):
    """CNN that maps a field to a probability in [0, 1] of it being real.

    Four 5x5 Conv+BatchNorm stages (the first three with stride 2) and LeakyReLU(0.1),
    then Linear -> Sigmoid. The Linear expects 256 x 8 x 8 features, i.e. a 64x64 input.

    Args:
        input_channels: channels of the input field.
    """

    def __init__(self, input_channels):
        super(Discriminator, self).__init__()
        self.activ = nn.LeakyReLU(0.1, inplace=True)

        def conv_bn(cin, cout, stride):
            # 5x5 conv with padding 2 followed by BatchNorm over cout channels.
            return nn.Sequential(
                nn.Conv2d(cin, cout, kernel_size=5, padding=2, stride=stride),
                nn.BatchNorm2d(cout),
            )

        self.conv1 = conv_bn(input_channels, 32, 2)
        self.conv2 = conv_bn(32, 64, 2)
        self.conv3 = conv_bn(64, 128, 2)
        self.conv4 = conv_bn(128, 256, 1)

        self.dense_layer = nn.Sequential(nn.Linear(256 * 8 * 8, 1), nn.Sigmoid())

    def forward(self, ims):
        """Return a (batch, 1) tensor of probabilities."""
        out = ims
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            out = self.activ(stage(out))
        return self.dense_layer(out.reshape(out.shape[0], -1))

    
def noise(bz, div):
    """Return a (bz, 1) CPU tensor of uniform noise in [0, 1/div), drawn from torch's global RNG."""
    sample = torch.rand(bz, 1)
    return sample / div


In [None]:
# Reproducibility: seed every RNG the notebook uses before building loaders and models.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

batch_size = 30
losses = []        # per-training-batch [pixel MSE, D real BCE, D fake BCE]
min_mse = 1        # NOTE(review): not read anywhere in the visible notebook
mid = 30           # number of leading frames fed to the generator (x = frames[:mid])

lambda_1 = 0.001   # weight of the feature-space MSE term in the generator loss
lambda_2 = 0.001   # weight of the divergence term in the generator loss

lr_g = 0.01
lr_ds = 0.003
kernel_size = 3
time_window = 5    # target frame is frames[mid + time_window]

# Sample IDs: 50 train / 5 validation / 5 test files.
train_indices = list(range(0,  50))
valid_indices = list(range(50, 55))
test_indices = list(range(55, 60))

train_set = Dataset(train_indices, mid, train_direc, time_window)
valid_set = Dataset(valid_indices, mid, train_direc, time_window)
test_set = Dataset(test_indices, mid, train_direc, time_window)
train_loader = data.DataLoader(train_set, batch_size = batch_size, shuffle = True)
valid_loader = data.DataLoader(valid_set, batch_size = batch_size, shuffle = False)
test_loader = data.DataLoader(test_set, batch_size = batch_size, shuffle = False)

feature_extractor = FeatureExtractor(in_channels=2).to(device)
# Not registered with any optimizer: freeze it so backward() stops accumulating
# never-zeroed .grad buffers in its weights. Gradients still flow THROUGH it to gen_imgs.
feature_extractor.requires_grad_(False)
feature_extractor.eval()

# NOTE: `filter` shadows the Python builtin; the name is kept for compatibility.
filter = LES(input_channels=mid*2, output_channels=2, kernel_size=3, time_range=5)
Gen = Generator(mid*2, filter).to(device)
Dis = Discriminator(2).to(device)


optimizer_G = torch.optim.Adam(Gen.parameters(), lr = lr_g, betas=(0.9, 0.999), weight_decay=4e-4)
optimizer_D = torch.optim.Adam(Dis.parameters(), lr = lr_ds, betas=(0.9, 0.999), weight_decay=4e-4)

# Each scheduler.step() multiplies the learning rate by 0.9.
scheduler_G = torch.optim.lr_scheduler.StepLR(optimizer_G, step_size= 1, gamma=0.9)
scheduler_D = torch.optim.lr_scheduler.StepLR(optimizer_D, step_size= 1, gamma=0.9)

loss_fun = torch.nn.BCELoss()   # discriminator loss (Dis already ends in a Sigmoid)
loss_mse = torch.nn.MSELoss()


In [None]:
# Training: each batch updates the generator (pixel MSE + feature MSE + divergence
# penalty), then the discriminator (BCE on real vs. detached generated fields).
# Every 20 epochs a validation pass runs with both models in eval mode.
train_mse = []   # per-epoch training RMSE
valid_mse = []   # per-validation-round RMSE

# Constructed once instead of once per batch.
div_loss_fun = DivergenceLoss2(loss_mse)


def make_labels(n):
  """Return noisy BCE targets (valid, fake), each of shape (n, 1), on `device`.

  valid = 1 - noise(n, 10) and fake = 0 + noise(n, 10): label smoothing by up to 0.1.
  """
  valid = (torch.ones(n, 1) - noise(n, 10)).to(device)
  fake = (torch.zeros(n, 1) + noise(n, 10)).to(device)
  return valid, fake


def generator_losses(xx, real_imgs):
  """Run `Gen` on one batch and return (gen_imgs, g_loss, g_loss1).

  g_loss1 is the pixel MSE, and
  g_loss = g_loss1 + lambda_1 * feature MSE + lambda_2 * divergence loss.
  """
  gen_imgs = Gen(xx)

  # Per-sample divergence loss on (1, 2, H, W) views. Assumes the 2 output channels
  # are the two velocity components DivergenceLoss2 expects — TODO confirm.
  height, width = gen_imgs.shape[-2], gen_imgs.shape[-1]
  div_fake_imgs = gen_imgs.reshape(gen_imgs.shape[0], 1, 2, height, width)
  div_real_imgs = real_imgs.reshape(gen_imgs.shape[0], 1, 2, height, width)
  # Accumulate TENSORS: the old `.item()` turned this term into a constant, so
  # lambda_2 * div_loss contributed no gradient at all.
  div_loss = 0
  for y_hat, y_true in zip(div_fake_imgs, div_real_imgs):
    div_loss = div_loss + div_loss_fun(y_hat, y_true)

  gen_features = feature_extractor(gen_imgs)
  real_features = feature_extractor(real_imgs)

  g_loss1 = loss_mse(gen_imgs, real_imgs)
  g_loss2 = loss_mse(gen_features, real_features)
  g_loss = g_loss1 + lambda_1 * g_loss2 + lambda_2 * div_loss
  return gen_imgs, g_loss, g_loss1


def discriminator_losses(real_imgs, gen_imgs, valid, fake):
  """Return (real_loss, fake_loss): BCE of Dis on real fields vs `valid`, and on
  detached generated fields vs `fake`."""
  d_real = Dis(real_imgs)
  d_fake = Dis(gen_imgs.detach())
  return loss_fun(d_real, valid), loss_fun(d_fake, fake)


for epoch in range(200):
  print('            ')
  print('%%%%%%%%%%%%')
  print('Epoch: {}'.format(epoch))
  start = time.time()
  Gen.train()
  Dis.train()

  mse = []
  for xx, real_imgs in train_loader:
    xx, real_imgs = xx.to(device), real_imgs.to(device)
    valid, fake = make_labels(xx.size(0))

    optimizer_G.zero_grad()
    gen_imgs, g_loss, g_loss1 = generator_losses(xx, real_imgs)
    # NOTE(review): g_loss has no adversarial term (e.g. loss_fun(Dis(gen_imgs), valid)),
    # so the discriminator never influences the generator — confirm this is intended.
    g_loss.backward()
    optimizer_G.step()

    optimizer_D.zero_grad()
    real_loss_s, fake_loss_s = discriminator_losses(real_imgs, gen_imgs, valid, fake)
    ds_loss = real_loss_s + fake_loss_s
    ds_loss.backward()
    optimizer_D.step()

    # gen_imgs predates optimizer_G.step(), so g_loss1 is exactly this batch's pixel MSE.
    mse.append(g_loss1.item())
    gan_loss = [round(g_loss1.item(), 3), round(real_loss_s.item(), 3), round(fake_loss_s.item(), 3)]
    losses.append(gan_loss)

  # Step the LR schedulers after the epoch's optimizer steps (PyTorch >= 1.1 order).
  # Stepping them first, as before, skipped the initial learning rate.
  scheduler_G.step()
  scheduler_D.step()

  # The reported "MSE" is sqrt(mean of per-batch MSEs), i.e. an RMSE.
  mse = round(np.sqrt(np.mean(mse)), 5)
  train_mse.append(mse)
  end = time.time()
  print('Time taken:{} s'.format(round((end-start),2)))
  print('MSE Loss: {}, Generator Loss: {}, Discriminator Loss: {}'.format(mse, gan_loss[0], gan_loss[1]+gan_loss[2]))

  if epoch % 20 == 0:
    print('            ')
    print('%%%%%%%%%%%%')
    print('Validation_round')
    start = time.time()
    # eval(): BatchNorm must use, not update, its running stats on validation data.
    Gen.eval()
    Dis.eval()
    mse = []
    with torch.no_grad():
      for xx, real_imgs in valid_loader:
        xx, real_imgs = xx.to(device), real_imgs.to(device)
        valid, fake = make_labels(xx.size(0))
        gen_imgs, g_loss, g_loss1 = generator_losses(xx, real_imgs)
        real_loss_s, fake_loss_s = discriminator_losses(real_imgs, gen_imgs, valid, fake)
        mse.append(g_loss1.item())
        # Kept out of `losses`, which logs training batches only.
        val_gan_loss = [round(g_loss1.item(), 3), round(real_loss_s.item(), 3), round(fake_loss_s.item(), 3)]
    mse = round(np.sqrt(np.mean(mse)), 5)
    valid_mse.append(mse)   # previously appended to train_mse by mistake
    end = time.time()
    print('Time taken:{} s'.format(round((end-start),2)))
    print('MSE Loss: {}, Generator Loss: {}, Discriminator Loss: {}'.format(mse, val_gan_loss[0], val_gan_loss[1]+val_gan_loss[2]))
    Gen.train()
    Dis.train()

In [None]:
def test_epoch(model, loader=None):
  """Evaluate `model` on held-out data and return (rmse, last predictions, last targets).

  Args:
    model: generator mapping inputs `xx` to predictions shaped like the targets.
    loader: DataLoader yielding (xx, real_imgs). Defaults to `test_loader`; the
      function previously always used `valid_loader` despite its name.
      Pass `valid_loader` explicitly to reproduce that behavior.

  Returns:
    (rmse, gen_imgs, real_imgs): rmse is sqrt(mean of per-batch MSEs) rounded to 5
    decimals; gen_imgs / real_imgs are the tensors from the final batch.

  Raises:
    ValueError: if the loader yields no batches.
  """
  if loader is None:
    loader = test_loader
  was_training = model.training
  # eval(): BatchNorm uses its running stats instead of updating them with held-out data.
  model.eval()
  mse = []
  gen_imgs = real_imgs = None
  try:
    with torch.no_grad():
      for xx, real_imgs in loader:
        xx, real_imgs = xx.to(device), real_imgs.to(device)
        gen_imgs = model(xx)
        mse.append(loss_mse(gen_imgs, real_imgs).item())
  finally:
    # Restore whatever mode the caller had the model in.
    model.train(was_training)
  if not mse:
    raise ValueError('test_epoch: loader yielded no batches')
  mse = round(np.sqrt(np.mean(mse)), 5)
  print('MSE Loss: {}'.format(mse))
  return mse, gen_imgs, real_imgs


In [None]:
# Compute predictions here instead of relying on `preds` from the cell below, which
# made this cell fail on Restart & Run All (preds was not yet defined).
eval_rmse, preds, eval_targets = test_epoch(Gen)
print(preds.shape)

In [None]:
# Ground truth vs. prediction for one sample and one channel of the last evaluation batch.
loss, preds, ims = test_epoch(Gen)
print(loss)
sample_idx = min(2, ims.shape[0] - 1)   # was a hard-coded 2; clamp for small final batches
channel_idx = 1
fig, axs = plt.subplots(2, 1, figsize=(6, 8))
sns.heatmap(ax=axs[0], data=ims[sample_idx, channel_idx, :, :].cpu().numpy())
axs[0].set_title('Ground truth (sample {}, channel {})'.format(sample_idx, channel_idx))
sns.heatmap(ax=axs[1], data=preds[sample_idx, channel_idx, :, :].cpu().numpy())
axs[1].set_title('Prediction (sample {}, channel {})'.format(sample_idx, channel_idx))
fig.tight_layout()
plt.show()





In [None]:
# Persist the whole generator module (a pickled object). Reloading it needs the class
# definitions above to be importable; Gen.state_dict() would be the portable alternative.
model_path = '/content/gdrive/MyDrive/PINNGan_model.pth'
torch.save(Gen, model_path)