In [None]:
import os
import re
import gc

import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
from torchvision.transforms import ToTensor
from torchinfo import summary

from tqdm import tqdm

# Data Loading

In [None]:
class SkyDataset(Dataset):
    """Paired frame sequences read from two directory trees.

    Every entry of ``train_dir`` and of ``gt_dir`` is treated as a directory
    holding the frames of one sequence. Input and ground-truth sequences are
    paired purely by their position in the sorted directory listings.

    NOTE(review): nothing checks that both trees contain the same number of
    entries (or matching names); the dataset length follows ``train_dir`` only.
    """

    def __init__(self, train_dir, gt_dir):
        self.train_root = train_dir
        self.gt_root = gt_dir
        # Sorted so that index i of both listings refers to the same position.
        self.train_dirs = sorted(os.listdir(train_dir))
        self.gt_dirs = sorted(os.listdir(gt_dir))

    def __len__(self):
        """Number of sequences, taken from the input (train) tree."""
        return len(self.train_dirs)

    def _load_sequence(self, root, seq_name):
        """Read every file of ``root/seq_name`` in sorted order and stack them.

        Each image is divided by 255.0 before stacking, so the result is a
        floating-point tensor with a new leading (frame) dimension.
        """
        seq_path = os.path.join(root, seq_name)
        frames = []
        for frame_name in sorted(os.listdir(seq_path)):
            frame = read_image(os.path.join(seq_path, frame_name))
            frames.append(frame / 255.0)
        return torch.stack(frames)

    def __getitem__(self, idx):
        """Return ``(train_seq, gt_seq)`` for the sequence at position ``idx``."""
        train_seq = self._load_sequence(self.train_root, self.train_dirs[idx])
        gt_seq = self._load_sequence(self.gt_root, self.gt_dirs[idx])
        return train_seq, gt_seq


In [None]:
# Locations of the input sequences and of their ground-truth counterparts.
input_dir = '../SkyDataset/train'
gt_dir = '../SkyDataset/gt'

# Batches of two sequence pairs, reshuffled every epoch.
train_set = SkyDataset(train_dir=input_dir, gt_dir=gt_dir)
train_dataloader = DataLoader(
    dataset=train_set,
    batch_size=2,
    shuffle=True,
)

In [None]:
# Use the GPU when CUDA is available, otherwise stay on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Tensors created without an explicit device are placed on `device` from here on.
torch.set_default_device(device)

# Drop unreachable Python objects first, then release cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()

# Model Construction

In [None]:
class MotionEncoder(nn.Module):
    """Convolutional encoder that maps a motion field to a latent vector.

    Four 3x3 stride-2 convolutions (each followed by batch norm and LeakyReLU)
    are followed by one 8x8 stride-8 convolution; the flattened result is
    projected to ``out_c`` values by a linear layer. With a 128x128 input the
    convolutional stack yields a 256x1x1 map, which matches the 256 input
    features of the linear layer used when ``in_c <= 2``.

    Args:
        in_c: number of input channels (default 2).
        out_c: size of the latent vector (default 8).
    """

    def __init__(self, in_c=2, out_c=8):
        super().__init__()

        # (in_channels, out_channels) of the four down-sampling stages.
        stage_channels = ((in_c, 64), (64, 128), (128, 256), (256, 256))

        layers = []
        for c_in, c_out in stage_channels:
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU())
        # Final stage collapses each 8x8 window into a single position.
        layers.append(nn.Conv2d(256, 256, kernel_size=8, stride=8, padding=0))
        layers.append(nn.LeakyReLU())
        self.conv_layers = nn.Sequential(*layers)

        # NOTE(review): 2560 presumably corresponds to a different input
        # resolution used with more than two channels - confirm.
        in_dim = 2560 if in_c > 2 else 256
        self.fc = nn.Linear(in_dim, out_c)  # latent

    def forward(self, x):
        """Encode ``x`` and return a tensor with one ``out_c``-sized row per sample."""
        features = self.conv_layers(x)
        flat = features.view(x.size(0), -1)
        return self.fc(flat)

In [None]:
class ConvLSTMCell(nn.Module):
    """A single convolutional LSTM cell.

    The input and the current hidden state are concatenated along the channel
    dimension and passed through one convolution that produces all four gate
    pre-activations at once.

    Args:
        input_dim: number of channels of the input tensor.
        hidden_dim: number of channels of the hidden and cell states.
        kernel_size: (integer) size of the square convolution kernel; the
            padding is ``kernel_size // 2``.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size):
        super(ConvLSTMCell, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        self.padding = kernel_size // 2

        # One convolution emits the i, f, o and g pre-activations stacked
        # along the channel dimension (hence 4 * hidden_dim outputs).
        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4*self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding)

    def forward(self, input_tensor, cur_state):
        """Advance the cell by one step.

        Args:
            input_tensor: input for this step, concatenated with the hidden
                state along dim 1.
            cur_state: pair ``(h_cur, c_cur)`` of current hidden and cell state.

        Returns:
            Tuple ``(h_next, c_next)`` with the updated hidden and cell state.
        """
        h_cur, c_cur = cur_state
        combined = torch.cat([input_tensor, h_cur], dim=1)

        combined_conv = self.conv(combined)
        # Chunks of hidden_dim channels, in the order: input, forget, output, candidate.
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)
        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """Return zero-filled ``(h, c)`` states for a new sequence.

        The states are allocated on the same device and with the same dtype as
        the cell's convolution weights, so they stay compatible with the module
        after ``.to(...)``/``.double()`` regardless of the global default
        device or dtype.

        Args:
            batch_size: number of samples in the batch.
            image_size: ``(height, width)`` of the state maps.

        Returns:
            Tuple of two tensors shaped ``(batch_size, hidden_dim, height, width)``.
        """
        height, width = image_size
        weight = self.conv.weight
        shape = (batch_size, self.hidden_dim, height, width)
        return (torch.zeros(shape, device=weight.device, dtype=weight.dtype),
                torch.zeros(shape, device=weight.device, dtype=weight.dtype))


In [None]:
class CAE(nn.Module):
    """ConvLSTM encoder/decoder that predicts future frames from a sequence.

    Two stacked ConvLSTM cells (``e1``, ``e2``) consume the input sequence,
    two more (``d1``, ``d2``) unroll the prediction, and a 3D convolution
    (``d3``) followed by a sigmoid maps the decoder states to 3 channels.

    Args:
        nf: number of hidden channels used by every ConvLSTM cell.
        in_chan: number of channels of each input frame.
    """

    def __init__(self, nf, in_chan):
        super().__init__()
        self.nf = nf
        # Encoder cells.
        self.e1 = ConvLSTMCell(input_dim=in_chan, hidden_dim=nf, kernel_size=3)
        self.e2 = ConvLSTMCell(input_dim=nf, hidden_dim=nf, kernel_size=3)
        # Decoder cells.
        self.d1 = ConvLSTMCell(input_dim=nf, hidden_dim=nf, kernel_size=3)
        self.d2 = ConvLSTMCell(input_dim=nf, hidden_dim=nf, kernel_size=3)
        # Output head applied over the (channel, time, height, width) volume.
        self.d3 = nn.Conv3d(in_channels=nf, out_channels=3, kernel_size=3, padding=1)

    def autoencoder(self, x, seq_len, future_step, h_t, c_t, h_t2, c_t2, h_t3, c_t3, h_t4, c_t4):
        """Encode ``seq_len`` frames of ``x`` and decode ``future_step`` frames.

        The ``h_*``/``c_*`` arguments are the initial hidden and cell states of
        ``e1``, ``e2``, ``d1`` and ``d2`` respectively.

        Returns:
            Tensor of sigmoid activations laid out as
            ``(batch, 3, future_step, height, width)``.
        """
        # Encoder: feed the frames one by one through both encoder cells.
        for step in range(seq_len):
            frame = x[:, step]
            h_t, c_t = self.e1(frame, [h_t, c_t])
            h_t2, c_t2 = self.e2(h_t, [h_t2, c_t2])

        # Decoder: the final encoder hidden state is the input at every step.
        decoded = []
        for _ in range(future_step):
            h_t3, c_t3 = self.d1(h_t2, [h_t3, c_t3])
            h_t4, c_t4 = self.d2(h_t3, [h_t4, c_t4])
            decoded.append(h_t4)

        # Stack along a new time axis, then move channels in front of time
        # because Conv3d convolves over the last three dimensions.
        volume = torch.stack(decoded, 1).permute(0, 2, 1, 3, 4)
        return torch.sigmoid(self.d3(volume))

    def forward(self, x, future_seq=10):
        """Predict ``future_seq`` frames from ``x``.

        Args:
            x: 5-D tensor unpacked as ``(batch, seq_len, channels, height, width)``.
            future_seq: number of frames to predict (default 10).

        Returns:
            The output of :meth:`autoencoder`.
        """
        b, seq_len, _, h, w = x.size()

        # Fresh zero states for every cell, in the order e1, e2, d1, d2.
        cells = (self.e1, self.e2, self.d1, self.d2)
        states = [cell.init_hidden(batch_size=b, image_size=(h, w)) for cell in cells]
        (h_t, c_t), (h_t2, c_t2), (h_t3, c_t3), (h_t4, c_t4) = states

        return self.autoencoder(x, seq_len, future_seq,
                                h_t, c_t, h_t2, c_t2, h_t3, c_t3, h_t4, c_t4)

In [None]:
from torchvision.models.optical_flow import raft_small, Raft_Small_Weights
from torchvision.utils import flow_to_image
from torchvision import transforms

# Pre-trained RAFT (small) used as a fixed optical-flow estimator.
weights = Raft_Small_Weights.DEFAULT
flow_model = raft_small(weights=weights, progress=False).to(device)
flow_model = flow_model.eval()
# The flow network is never handed to an optimizer, so freeze its parameters:
# otherwise every backward pass would compute and accumulate gradients for it.
flow_model.requires_grad_(False)
# Pre-processing that goes with the chosen weights.
flow_tf = weights.transforms()

# Train

### Hyperparameters

In [None]:
lr = 1e-4          # learning rate
num_epochs = 800   # upper bound of the epoch loop in train_net
# Frame width and height in pixels (presumably the dataset resolution - confirm).
img_w, img_h = 640, 360

In [None]:
def train_net(me, cae, optim_me, optim_cae, start_epoch=0, checkpoint_dir='/kaggle/working'):
    """Jointly train the motion encoder and the ConvLSTM autoencoder.

    Relies on the notebook-level globals ``train_dataloader``, ``device``,
    ``flow_model``, ``flow_tf`` and ``num_epochs``.

    For every batch, the optical flow between the first two input frames is
    estimated with ``flow_model``, resized to 128x128 and encoded by ``me``
    into a latent vector. The vector is broadcast over all frames and pixels,
    concatenated to the input frames as extra channels, and ``cae`` predicts
    the future frames from the result. After every epoch a checkpoint and the
    loss history collected so far are written to disk.

    Args:
        me: motion encoder applied to the resized flow field.
        cae: autoencoder producing the predicted frames.
        optim_me: optimizer stepped for ``me``.
        optim_cae: optimizer stepped for ``cae``.
        start_epoch: first epoch index to run; pass the epoch following a
            restored checkpoint to continue its numbering (default 0).
        checkpoint_dir: directory receiving ``checkpoint_cae_<epoch>.pt``
            (default ``'/kaggle/working'``).
    """
    criterion = nn.MSELoss()
    train_loss = []

    for epoch in range(start_epoch, num_epochs):
        total_loss = 0
        for inputs, gts in train_dataloader:
            inputs = inputs.to(device)
            gts = gts.to(device)

            # find flow
            # The flow estimator is not trained here, so run it without
            # autograd bookkeeping; only the last (most refined) prediction
            # of the returned sequence is used.
            with torch.no_grad():
                b1, b2 = flow_tf(inputs[:, 0], inputs[:, 1])
                flow = flow_model(b1, b2)[-1]
                flow = F.interpolate(flow, size=(128, 128), mode='bilinear', align_corners=True)
            z = me(flow)

            # stack
            # Broadcast the latent vector over every frame and pixel, then
            # append it to the frames along the channel dimension.
            batch, steps, _, height, width = inputs.shape
            z_matched = z.view(z.size(0), 1, z.size(1), 1, 1).expand(batch, steps, z.size(1), height, width)
            inputs = torch.cat((inputs, z_matched), 2)

            # prediction
            # cae returns channels before time; swap them to line up with gts.
            preds = cae(inputs).permute(0, 2, 1, 3, 4)
            optim_me.zero_grad()
            optim_cae.zero_grad()

            loss = criterion(preds, gts)
            # NOTE(review): both operands are detached, so this first-to-last
            # frame difference term raises the reported loss but contributes
            # no gradient. Left as-is to keep training and logged values
            # unchanged - confirm whether preds was meant to stay attached.
            loss += criterion((preds[:, -1] - preds[:, 0]).detach(), (gts[:, -1] - gts[:, 0]).detach())

            loss.backward()
            optim_me.step()
            optim_cae.step()

            total_loss += loss.item()

        train_loss.append(float(total_loss) / len(train_dataloader))
        # Index from the end: train_loss only holds epochs run in this call.
        print("Epoch {}: Train loss: {}". format(epoch + 1, train_loss[-1]))

        torch.save({
            'epoch': epoch,
            'me': me.state_dict(),
            'cae': cae.state_dict(),
            'optim_me': optim_me.state_dict(),
            'optim_cae': optim_cae.state_dict()
            },
            os.path.join(checkpoint_dir, f'checkpoint_cae_{epoch}.pt'))
        np.savetxt(f'loss_{epoch}.txt', train_loss)

In [None]:
me = MotionEncoder()
me = me.to(device)
# in_chan = 11: frame channels plus the 8 latent channels appended in train_net
# (presumably 3 RGB channels + 8 - confirm against the dataset).
cae = CAE(nf = 22, in_chan = 11)
cae = cae.to(device)

# Use the shared hyperparameter instead of repeating the literal.
optim_me = Adam(me.parameters(), lr=lr)
optim_cae = Adam(cae.parameters(), lr=lr)

# map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
checkpoint = torch.load('checkpoints/checkpoint_cae_302.pt', map_location=device)
me.load_state_dict(checkpoint['me'])
# Older checkpoints store the autoencoder under the misnamed 'cvae' /
# 'optim_cvae' keys, while train_net writes 'cae' / 'optim_cae'; accept both.
cae_key = 'cae' if 'cae' in checkpoint else 'cvae'
optim_cae_key = 'optim_cae' if 'optim_cae' in checkpoint else 'optim_cvae'
cae.load_state_dict(checkpoint[cae_key])
optim_me.load_state_dict(checkpoint['optim_me'])
optim_cae.load_state_dict(checkpoint[optim_cae_key])

In [None]:
# Release cached, currently unused GPU memory before training starts.
torch.cuda.empty_cache()

In [None]:
# Start training with the restored models and optimizers.
train_net(me=me, cae=cae, optim_me=optim_me, optim_cae=optim_cae)