In [23]:
import matplotlib.pyplot as plt
from tqdm import tqdm
import os


import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [24]:
# device = torch.device('cuda' if torch.backends.cuda.is_available() else 'cpu')
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(device)

mps


In [25]:
os.listdir('./../../../Dataset_Student/unlabeled/video_10000/')

['image_15.png',
 'image_14.png',
 'image_16.png',
 'image_17.png',
 'image_13.png',
 'image_12.png',
 'image_10.png',
 'image_11.png',
 'image_8.png',
 'image_9.png',
 'image_2.png',
 'image_3.png',
 'image_1.png',
 'image_0.png',
 'image_4.png',
 'image_5.png',
 'image_7.png',
 'image_6.png',
 'image_20.png',
 'image_21.png',
 'image_19.png',
 'image_18.png']

In [26]:
frames = []

base_dir = './../../../Dataset_Student/unlabeled/video_10000/'

image_names = [f'image_{i}.png' for i in range(22)]

for file_name in image_names:
    img = plt.imread(base_dir + file_name)
    frames.append(img)
# plt.imshow(frames[3])
frames[3].shape

(160, 240, 3)

In [27]:
class CustomDataset(Dataset):
    def __init__(self, n_videos):
        self.video_idxs = torch.tensor([i for i in range(10000, n_videos+10000)])

    def __len__(self):
        return len(self.video_idxs)

    def __getitem__(self, idx):
#         global net_id
        i = self.video_idxs[idx]
        file_path = f'./../../../Dataset_Student/unlabeled/video_{i}/'
        x = []
        for j in range(11):
            x.append(torch.tensor(plt.imread(file_path+f'image_{j}.png')).permute(2, 0, 1))
        x = torch.stack(x, 0)
        y = []
        for j in range(11, 22):
            y.append(torch.tensor(plt.imread(file_path+f'image_{j}.png')).permute(2, 0, 1))
        y = torch.stack(y, 0)
        return x, y

In [28]:
batch_size = 8

# Create DataLoader
train_dataset = CustomDataset(5)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [29]:
for x, y in train_loader:
    print(x.shape, y.shape)
    break

torch.Size([5, 11, 3, 160, 240]) torch.Size([5, 11, 3, 160, 240])


In [30]:
# [1, 2, 1, 2]

In [31]:
class Encoder(nn.Module):
    def __init__(self, in_channels, hid_channels):
        super().__init__()
        self.inp_enc = nn.Sequential(nn.Conv2d(in_channels, hid_channels, kernel_size=3, stride=1, padding=1),
                                 nn.GroupNorm(2, hid_channels),
                                 nn.LeakyReLU(0.2))
        self.enc = nn.Sequential(nn.Conv2d(hid_channels, hid_channels, kernel_size=3, stride=2, padding=1),
                                 nn.GroupNorm(2, hid_channels),
                                 nn.LeakyReLU(0.2),
                                 nn.Conv2d(hid_channels, hid_channels, kernel_size=3, stride=1, padding=1),
                                 nn.GroupNorm(2, hid_channels),
                                 nn.LeakyReLU(0.2),
                                 nn.Conv2d(hid_channels, hid_channels, kernel_size=3, stride=2, padding=1),
                                 nn.GroupNorm(2, hid_channels),
                                 nn.LeakyReLU(0.2))
    def forward(self, x):
        b, t, c, h, w = x.shape
        x = self.inp_enc(x.view(b*t, c, h, w))
        x = self.inp_enc(x)
        res = x.clone()
        x = self.enc(x)
        return x, res

In [32]:
class GroupConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, groups, act_norm=False):
        super(GroupConv2d, self).__init__()
        self.act_norm = act_norm
        if in_channels % groups != 0:
            groups = 1
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, 
                              padding=padding,groups=groups)
        self.norm = nn.GroupNorm(groups,out_channels)
        self.activate = nn.LeakyReLU(0.2)
    
    def forward(self, x):
        y = self.conv(x)
        if self.act_norm:
            y = self.activate(self.norm(y))
        return y

In [33]:
class Inception(nn.Module):
    def __init__(self, C_in, C_hid, C_out, incep_ker=[3,5,7,11], groups=8):        
        super().__init__()
        self.conv1 = nn.Conv2d(C_in, C_hid, kernel_size=1, stride=1, padding=0)
        layers = []
        for ker in incep_ker:
            layers.append(GroupConv2d(C_hid, C_out, kernel_size=ker, stride=1, padding=ker//2, 
                                      groups=groups, act_norm=True))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        y = 0
        for layer in self.layers:
            y += layer(x)
        return y

In [34]:
class Translator(nn.Module):
    def __init__(self, channel_in, channel_hid, N_T, incep_ker=[3,5,7,11], groups=8):
        super().__init__()

        self.N_T = N_T
        enc_layers = [Inception(channel_in, channel_hid//2, channel_hid, incep_ker= incep_ker, groups=groups)]
        for i in range(1, N_T-1):
            enc_layers.append(Inception(channel_hid, channel_hid//2, channel_hid, incep_ker= incep_ker, groups=groups))
        enc_layers.append(Inception(channel_hid, channel_hid//2, channel_hid, incep_ker= incep_ker, groups=groups))

        dec_layers = [Inception(channel_hid, channel_hid//2, channel_hid, incep_ker= incep_ker, groups=groups)]
        for i in range(1, N_T-1):
            dec_layers.append(Inception(2*channel_hid, channel_hid//2, channel_hid, incep_ker= incep_ker, groups=groups))
        dec_layers.append(Inception(2*channel_hid, channel_hid//2, channel_in, incep_ker= incep_ker, groups=groups))

        self.enc = nn.Sequential(*enc_layers)
        self.dec = nn.Sequential(*dec_layers)

    def forward(self, x):
        B, T, C, H, W = x.shape
        x = x.reshape(B, T*C, H, W)

        # Encoder
        skips = []
        z = x
        for i in range(self.N_T):
            z = self.enc[i](z)
            if i < self.N_T - 1:
                skips.append(z)

        # Decoder
        z = self.dec[0](z)
        for i in range(1, self.N_T):
            z = self.dec[i](torch.cat([z, skips[-i]], dim=1))

        y = z.reshape(B, T, C, H, W)
        return y

In [35]:
# [2, 1, 2, 1]
class Decoder(nn.Module):
    def __init__(self, hid_channels, out_channels):
        super().__init__()
        self.dec = nn.Sequential(nn.ConvTranspose2d(hid_channels, hid_channels, kernel_size=3, stride=2, padding=1),
                                 nn.GroupNorm(2, hid_channels),
                                 nn.LeakyReLU(0.2),
                                 nn.ConvTranspose2d(hid_channels, hid_channels, kernel_size=3, stride=1, padding=1),
                                 nn.GroupNorm(2, hid_channels),
                                 nn.LeakyReLU(0.2),
                                 nn.ConvTranspose2d(hid_channels, hid_channels, kernel_size=3, stride=2, padding=1),
                                 nn.GroupNorm(2, hid_channels),
                                 nn.LeakyReLU(0.2))
        self.out_dec = nn.Sequential(nn.ConvTranspose2d(2*hid_channels, hid_channels, kernel_size=3, stride=1, padding=1),
                                 nn.GroupNorm(2, hid_channels),
                                 nn.LeakyReLU(0.2))
        
        self.out = nn.Conv2d(hid_channels, out_channels, 1)
        
    def forward(self, x, enc):
        x = self.dec(x)
        y = self.out_dec(torch.cat([x, enc], dim=1))
        y = self.out(y)
        return y

In [42]:
class SimVP(nn.Module):
    def __init__(self, shape_in, hidden_size=16, translator_size=256, incep_ker=[3,5,7,11], groups=8):
        super().__init__()
        T, C, H, W = shape_in
        self.enc = Encoder(C, hidden_size)
        self.hid = Translator(T*hidden_size, translator_size, 8, incep_ker, groups)
        self.dec = Decoder(hidden_size, C)


    def forward(self, x_raw):
        B, T, C, H, W = x_raw.shape
        x = x_raw.view(B*T, C, H, W)

        embed, skip = self.enc(x)
        _, C_, H_, W_ = embed.shape

        z = embed.view(B, T, C_, H_, W_)
        hid = self.hid(z)
        hid = hid.reshape(B*T, C_, H_, W_)

        Y = self.dec(hid, skip)
        Y = Y.reshape(B, T, C, H, W)
        return Y

In [44]:
model = SimVP(shape_in=(11, 3, 160, 240))
model(x)

ValueError: not enough values to unpack (expected 5, got 4)