In [1]:
import os
import gc

import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
from torchvision.transforms import ToTensor
from torchinfo import summary

from tqdm import tqdm

# Data Loading

In [2]:
class SkyDataset(Dataset):
    def __init__(self, train_dir, gt_dir):
        self.train_root = train_dir
        self.gt_root = gt_dir
        self.train_dirs = sorted(os.listdir(train_dir))
        self.gt_dirs = sorted(os.listdir(gt_dir))
        
    def __len__(self):
        return len(self.train_dirs)
    
    def __getitem__(self, idx):
        train_seq = torch.stack([read_image(os.path.join(self.train_root, self.train_dirs[idx], x))/255.0 for x in sorted(os.listdir(os.path.join(self.train_root, self.train_dirs[idx])))])
        gt_seq = torch.stack([read_image(os.path.join(self.gt_root, self.gt_dirs[idx], x))/255.0 for x in sorted(os.listdir(os.path.join(self.gt_root, self.gt_dirs[idx])))])
        return train_seq, gt_seq


In [3]:
input_dir = '../SkyDataset/train'
gt_dir = '../SkyDataset/gt'
train_set = SkyDataset(input_dir, gt_dir)
train_dataloader = DataLoader(train_set, batch_size=2, shuffle=True)

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_default_device(device)
gc.collect()
torch.cuda.empty_cache()

# Model Construction

In [4]:
class MotionEncoder(nn.Module):
    def __init__(self, in_c=2, out_c=8):
        super(MotionEncoder, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_c, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
            nn.Conv2d(256, 256, kernel_size=8, stride=8, padding=0),
            nn.LeakyReLU()
        )
        if in_c > 2:
            in_dim = 2560
        else:
            in_dim = 256
        self.fc = nn.Linear(in_dim, out_c) # latent
        self.fc_variance = nn.Linear(in_dim, out_c)

    def forward(self, x):
        x_c = self.conv_layers(x)
        x_flatten = x_c.view(x.size(0), -1)
        out = self.fc(x_flatten)
        var = self.fc_variance(x_flatten)
        return out, var

In [5]:
class MotionDecoder(nn.Module):
    def __init__(self, in_c=1280, out_c=64):
        super(MotionDecoder, self).__init__()
        self.fc = nn.Linear(in_c, 2560) # latent
        
        self.conv_layers = nn.Sequential(
            nn.ConvTranspose2d(256, 256, kernel_size=8, stride=8, padding=0, output_padding=0),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(256, 256, kernel_size=3, stride=2, padding=1, output_padding=(1,0)),
            nn.LeakyReLU(),
            nn.BatchNorm2d(256),
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=(1,0)),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=(1,0)),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(64, out_c, kernel_size=5, stride=2, padding=2, output_padding=(1,0)),
            nn.BatchNorm2d(out_c),
            nn.LeakyReLU()
        )
    
    def forward(self, z):
        x = self.fc(z)
        x = x.view(-1, 256, 2, 5)
        x_c = self.conv_layers(x)
        # upsample to match the original input size
        x_c = F.interpolate(x_c, size=(360, 640), mode='bilinear', align_corners=True)
        return x_c


In [6]:
class ConvLSTMCell(nn.Module):
    def __init__(self, input_dim, hidden_dim, kernel_size):
        super(ConvLSTMCell, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        self.padding = kernel_size // 2

        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4*self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding)

    def forward(self, input_tensor, cur_state):
        h_cur, c_cur = cur_state
        combined = torch.cat([input_tensor, h_cur], dim=1)

        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)
        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        height, width = image_size
        return (torch.zeros(batch_size, self.hidden_dim, height, width),
                torch.zeros(batch_size, self.hidden_dim, height, width))


In [7]:
class CVAE(nn.Module):
    def __init__(self, nf, in_chan):
        super(CVAE, self).__init__()
        self.nf = nf
        self.e1 = ConvLSTMCell(input_dim=in_chan, hidden_dim=nf,kernel_size=3)
        self.e2 = ConvLSTMCell(input_dim=nf, hidden_dim=nf, kernel_size=3)
        self.e_VAE = MotionEncoder(in_c=10, out_c=1280)
        self.d_VAE = MotionDecoder(in_c=1280, out_c=10)
        self.d1 = ConvLSTMCell(input_dim=nf, hidden_dim=nf, kernel_size=3)
        self.d2 = ConvLSTMCell(input_dim=nf, hidden_dim=nf, kernel_size=3)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std).cuda()
        return mu + noise * std

    def autoencoder(self, x, seq_len, future_step, h_t, c_t, h_t2, c_t2, h_t3, c_t3, h_t4, c_t4):
        outputs = []
        for t in range(seq_len):
            h_t, c_t = self.e1(input_tensor=x[:, t, :, :], cur_state=[h_t, c_t])
            h_t2, c_t2 = self.e2(input_tensor=h_t, cur_state=[h_t2, c_t2])
        mu, logvar = self.e_VAE(h_t2)
        z = self.reparameterize(mu, logvar)
        z = self.d_VAE(z)
        h_t2 = z
        for t in range(future_step):
            h_t3, c_t3 = self.d1(input_tensor=h_t2, cur_state=[h_t3, c_t3])
            h_t4, c_t4 = self.d2(input_tensor=h_t3, cur_state=[h_t4, c_t4])
            outputs += [h_t4]

        outputs = torch.stack(outputs, 1)
        outputs = outputs.permute(0, 2, 1, 3, 4)
        outputs = self.decoder_CNN(outputs)
        outputs = torch.nn.Sigmoid()(outputs)
        return outputs

    def forward(self, x, future_seq=10):
        b, seq_len, c, h, w = x.size()

        h_t, c_t = self.e1.init_hidden(batch_size=b, image_size=(h, w))
        h_t2, c_t2 = self.e2.init_hidden(batch_size=b, image_size=(h, w))
        h_t3, c_t3 = self.d1.init_hidden(batch_size=b, image_size=(h, w))
        h_t4, c_t4 = self.d2.init_hidden(batch_size=b, image_size=(h, w))

        outputs = self.autoencoder(x, seq_len, future_seq, h_t, c_t, h_t2, c_t2, h_t3, c_t3, h_t4, c_t4)
        return outputs

In [8]:
from torchvision.models.optical_flow import raft_small, Raft_Small_Weights
from torchvision.utils import flow_to_image
from torchvision import transforms

weights = Raft_Small_Weights.DEFAULT
flow_model = raft_small(weights=Raft_Small_Weights.DEFAULT, progress=False).to(device)
flow_model = flow_model.eval()
flow_tf = weights.transforms()

Downloading: "https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth" to /root/.cache/torch/hub/checkpoints/raft_small_C_T_V2-01064c6d.pth


# Train

### Hyperparams

In [10]:
lr=1e-4
num_epochs=200
img_w = 640
img_h = 360

In [12]:
def train_net(me, cvae, optim_me, optim_cvae):
    criterion = nn.MSELoss()
    train_loss = []
    
    for epoch in range(num_epochs):
        total_loss = 0
        for inputs, gts in train_dataloader:
            inputs = inputs.to(device)
            gts = gts.to(device)
            
            # find flow
            b1 = inputs[:, 0]
            b2 = inputs[:, 1]
            b1, b2 = flow_tf(b1, b2)
            flow = torch.stack(flow_model(b1, b2))
            flow = flow[-1][0].unsqueeze(0)
            flow = F.interpolate(flow, size=(128,128), mode='bilinear', align_corners=True)
            z, logvar = me(flow)

            # stack
            z_matched = z.view(z.size(0), z.size(1), 1, 1).expand(inputs.size(0), inputs.size(1), z.size(1), inputs[0].size(2), inputs[0].size(3))      
            inputs = torch.cat((inputs, z_matched), 2)
            
            # inference
            preds = cvae(inputs).permute(0, 2, 1, 3, 4)
            optim_me.zero_grad()
            optim_cvae.zero_grad()
            loss = criterion(preds, gts)  # reconstruction MSE loss
            loss += torch.mean(-0.5 * torch.sum(1 + logvar - z ** 2 - logvar.exp(), dim = 1), dim = 0)  # KL
            loss.backward()
            optim_me.step()
            optim_cvae.step()
            
            total_loss += loss.item()

        train_loss.append(float(total_loss) / len(train_dataloader))
        print("Epoch {}: Train loss: {}". format(epoch + 1, train_loss[epoch]))
        if (epoch + 1)%10 == 0:
            torch.save({
                'epoch': epoch,
                'me': me.state_dict(),
                'cvae': cvae.state_dict(),
                'optim_me': optim_me.state_dict(),
                'optim_cvae': optim_cvae.state_dict()
                }, 
                f'/checkpoints/checkpoint_cvae_{epoch}.pt')
        np.savetxt(f'loss_{epoch}.txt', train_loss)


In [13]:
me = MotionEncoder()
me = me.cuda(0)
cvae = CVAE(nf = 10, in_chan = 11)
cvae = cvae.cuda(0)

optim_me = Adam(me.parameters(), lr=1e-4)
optim_cvae = Adam(cvae.parameters(), lr=1e-4)

# checkpoint = torch.load('checkpoints/checkpoint_cvae_0.pt')
# me.load_state_dict(checkpoint['me'])
# cvae.load_state_dict(checkpoint['cvae'])
# optim_me.load_state_dict(checkpoint['optim_me'])
# optim_cvae.load_state_dict(checkpoint['optim_cvae'])

In [14]:
gc.collect()
torch.cuda.empty_cache()

In [15]:
train_net(me, cvae, optim_me, optim_cvae)



Epoch 1: Train loss: 0.005803676136828801


Epoch 2: Train loss: 0.005804224375714647


Epoch 3: Train loss: 0.005815277728153036


Epoch 4: Train loss: 0.005852719964063231


Epoch 5: Train loss: 0.005864058790649189


Epoch 6: Train loss: 0.005861134889555421


Epoch 7: Train loss: 0.005821229141958533


Epoch 8: Train loss: 0.0057949260590558355


Epoch 9: Train loss: 0.005796854089985186


Epoch 10: Train loss: 0.0057603183608660674


Epoch 11: Train loss: 0.00574735732986889


Epoch 12: Train loss: 0.0057680506298833705


Epoch 13: Train loss: 0.005770507117020006


Epoch 14: Train loss: 0.005762431614021671


Epoch 15: Train loss: 0.005773173513366504


Epoch 16: Train loss: 0.005779150041176918


Epoch 17: Train loss: 0.005764734683597976


Epoch 18: Train loss: 0.005706065482994977


Epoch 19: Train loss: 0.005719612614112966


Epoch 20: Train loss: 0.00571348830819764


Epoch 21: Train loss: 0.00571715617255169


Epoch 22: Train loss: 0.005729022412065496


Epoch 23: Train loss: 0.005687997368303068


Epoch 24: Train loss: 0.005710234564352543


Epoch 25: Train loss: 0.005712788501516619


Epoch 26: Train loss: 0.005745072486473525


Epoch 27: Train loss: 0.005706334969108092


Epoch 28: Train loss: 0.00564785862757646


Epoch 29: Train loss: 0.005629288476515323


Epoch 30: Train loss: 0.005729566538270484


Epoch 31: Train loss: 0.00578799356627179


Epoch 32: Train loss: 0.005735208047553897


Epoch 33: Train loss: 0.005688068975119831


Epoch 34: Train loss: 0.005614016248666226


Epoch 35: Train loss: 0.005598634214913275


Epoch 36: Train loss: 0.0056140958499956


Epoch 37: Train loss: 0.005595742262146892


Epoch 38: Train loss: 0.005600848146020732


Epoch 39: Train loss: 0.005581310801604327


Epoch 40: Train loss: 0.005576675581408942


Epoch 41: Train loss: 0.005580875124940847


Epoch 42: Train loss: 0.0055766456413697055


Epoch 43: Train loss: 0.005575444607777482


Epoch 44: Train loss: 0.005623326295035634


Epoch 45: Train loss: 0.005674237097078498


Epoch 46: Train loss: 0.0056103289890241746


Epoch 47: Train loss: 0.0055329817188705535


Epoch 48: Train loss: 0.005481103901810786


Epoch 49: Train loss: 0.00548230285280721


Epoch 50: Train loss: 0.005506481591889516


Epoch 51: Train loss: 0.0056261521288530625


Epoch 52: Train loss: 0.0054761372377818565


Epoch 53: Train loss: 0.0054633372027347695


Epoch 54: Train loss: 0.005499559172250806


Epoch 55: Train loss: 0.00554237147753543


Epoch 56: Train loss: 0.005551245536139988


Epoch 57: Train loss: 0.005449077393859625


Epoch 58: Train loss: 0.005373258654900054


Epoch 59: Train loss: 0.005349917547341357


Epoch 60: Train loss: 0.005335770126313288


Epoch 61: Train loss: 0.005355028996362965


Epoch 62: Train loss: 0.005401861839036041


Epoch 63: Train loss: 0.005376352899846561


Epoch 64: Train loss: 0.005362083407198178


Epoch 65: Train loss: 0.005363573903772742


Epoch 66: Train loss: 0.005364898909279641


Epoch 67: Train loss: 0.005350543651729822


Epoch 68: Train loss: 0.005338904052536855


Epoch 69: Train loss: 0.0053516154732317365


Epoch 70: Train loss: 0.005314644390439734


Epoch 71: Train loss: 0.005358039144862522


Epoch 72: Train loss: 0.005325092649721402


Epoch 73: Train loss: 0.005231908297641797


Epoch 74: Train loss: 0.005229404479502997


Epoch 75: Train loss: 0.005233148128745404


Epoch 76: Train loss: 0.005236182322210454


Epoch 77: Train loss: 0.005198297615935828


Epoch 78: Train loss: 0.005232748862831516


Epoch 79: Train loss: 0.005219630790042116


Epoch 80: Train loss: 0.005240845171972475


Epoch 81: Train loss: 0.005237689317065351


Epoch 82: Train loss: 0.005229637790669469


Epoch 83: Train loss: 0.005234424857066032


Epoch 84: Train loss: 0.0051735718050932


Epoch 85: Train loss: 0.005167115242914


Epoch 86: Train loss: 0.005175815079797139


Epoch 87: Train loss: 0.005113677747864673


Epoch 88: Train loss: 0.005134410477500964


Epoch 89: Train loss: 0.0051669862745527895


Epoch 90: Train loss: 0.005147346456911652


Epoch 91: Train loss: 0.005150317868336718


Epoch 92: Train loss: 0.005121785892728478


Epoch 93: Train loss: 0.005114337840573268


Epoch 94: Train loss: 0.005155161042955328


Epoch 95: Train loss: 0.005249772898535779


Epoch 96: Train loss: 0.005092267894205895


Epoch 97: Train loss: 0.005091665576192293


Epoch 98: Train loss: 0.005075403379197133


Epoch 99: Train loss: 0.005018807192669904


Epoch 100: Train loss: 0.00502141490875882


Epoch 101: Train loss: 0.00504965850033183


Epoch 102: Train loss: 0.005093088919455383


Epoch 103: Train loss: 0.005091099611463699


Epoch 104: Train loss: 0.005397963285089491


Epoch 105: Train loss: 0.0052641856880422605


Epoch 106: Train loss: 0.005048872866330946


Epoch 107: Train loss: 0.004965303192271831


Epoch 108: Train loss: 0.004913390085021866


Epoch 109: Train loss: 0.004890763821040697


Epoch 110: Train loss: 0.0048722194507718085


Epoch 111: Train loss: 0.004915029187983972


Epoch 112: Train loss: 0.004966070078947443


Epoch 113: Train loss: 0.004970508376929037


Epoch 114: Train loss: 0.004979134408479675


Epoch 115: Train loss: 0.004942491757584379


Epoch 116: Train loss: 0.004940036340477936


Epoch 117: Train loss: 0.0049544725745440795


Epoch 118: Train loss: 0.0050188877382018465


Epoch 119: Train loss: 0.005033285924768511


Epoch 120: Train loss: 0.005010848370519407


Epoch 121: Train loss: 0.004924836236626861


Epoch 122: Train loss: 0.004909707724731019


Epoch 123: Train loss: 0.004889348327637987


Epoch 124: Train loss: 0.004901974423332734


Epoch 125: Train loss: 0.004909422654817079


Epoch 126: Train loss: 0.0048725179010169935


Epoch 127: Train loss: 0.004846666826608967


Epoch 128: Train loss: 0.004895451820475307


Epoch 129: Train loss: 0.004927607909082732


Epoch 130: Train loss: 0.004923235074161215


Epoch 131: Train loss: 0.004975719564653774


Epoch 132: Train loss: 0.004931229569929394


Epoch 133: Train loss: 0.004952570229293184


Epoch 134: Train loss: 0.0050367330062262555


Epoch 135: Train loss: 0.004953489601215784


Epoch 136: Train loss: 0.0048140859013383695


Epoch 137: Train loss: 0.004753931016324366


Epoch 138: Train loss: 0.004735215507606243


Epoch 139: Train loss: 0.004794292985164422


Epoch 140: Train loss: 0.0048696352594948195


Epoch 141: Train loss: 0.004907230130258075


Epoch 142: Train loss: 0.004918612345577555


Epoch 143: Train loss: 0.004853357237942041


Epoch 144: Train loss: 0.004853311627905103


Epoch 145: Train loss: 0.004839344142361524


Epoch 146: Train loss: 0.004816694385273025


Epoch 147: Train loss: 0.004779449810690068
