In [1]:
# import torch.nn as nn
# import torch
# import torch.functional as F
# import numpy
from fastai.vision.all import *

In [2]:
torch.cuda.set_device(0)
torch.cuda.get_device_name()

'NVIDIA GeForce RTX 3090'

In [3]:
x = torch.rand(8,3,64,64) #single image
x_t = torch.rand(8,5,3,64,64) #multi frames

In [4]:
@delegates(create_cnn_model)
class Encoder(Module):
    def __init__(self, arch=resnet34, n_in=3, weights_file=None, head=True, **kwargs):
        "Encoder based on resnet, if head=False returns the feature map"
        model = create_cnn_model(arch, n_out=1, n_in=n_in, pretrained=True, **kwargs)
        if weights_file is not None: load_model(weights_file, model, opt=None)
        self.body = model[0]
        if head: self.head = model[1]
        else:    self.head = nn.Sequential(*(model[1][0:3]))

    def forward(self, x):
        return self.head(self.body(x))

In [5]:
enc = Encoder(n_in=3,weights_file=None,head=False)

Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /mnt/home/hheat/.cache/torch/hub/checkpoints/resnet34-b627a593.pth


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=87319819.0), HTML(value='')))




In [76]:
enc

Encoder(
  (body): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  

In [7]:
preds = enc(x)
preds.shape

torch.Size([8, 1024])

In [9]:
time_enc = TimeDistributed(enc)

In [10]:
time_enc(x_t).shape

torch.Size([8, 5, 1024])

In [12]:
num_features_model??

In [73]:
class SimpleModel(Module):
    "A simple CNN model"
    def __init__(self, arch=resnet34, weights_file=None, num_classes=30, seq_len=40, debug=False):
        "Create a simple arch based model"
        model = Encoder(arch, 3, weights_file, head=False)
        nf = num_features_model(nn.Sequential(*model.body.children()))*2
        self.encoder = TimeDistributed(model)
#         self.head = nn.Sequential(LinBnDrop(nf,  nf//2, p=0.2, act=nn.ReLU()),
#                                   LinBnDrop(nf//2, num_classes, p=0.05))
        self.attention_layer = nn.Linear(nf, 1)
        self.debug = debug

    def forward(self, x):
        if self.debug:  print(f' input len:   {x.shape[1], x.shape}')
#         x = torch.stack(x, dim=1)
#         if self.debug:  print(f' after stack:   {x.shape}')
        batch_size, seq_length, c, h, w = x.shape
        #x = x.view(batch_size * seq_length, c, h, w)
        x = self.encoder(x)
        #x = x.view(batch_size, seq_length, -1)
        if self.debug:  print(f' encoded shape: {x.shape}')
        attention_w = F.softmax(self.attention_layer(x).squeeze(-1), dim=-1)
        x = torch.sum(attention_w.unsqueeze(-1) * x, dim=1)
        #x = attention_w * x
        if self.debug:  print(f' after attention shape: {x.shape}')
        #x = self.head(x)
        return x

In [74]:
t_model = SimpleModel(debug=True)

In [75]:
out = t_model(x_t)

 input len:   (5, torch.Size([8, 5, 3, 64, 64]))
 encoded shape: torch.Size([8, 5, 1024])
 after attention shape: torch.Size([8, 1024])


In [41]:
out.shape

torch.Size([8, 30])

In [43]:
inp_x = torch.rand(8,5,1024)
a_lin = nn.Linear(1024,1024)
p = a_lin(inp_x)
p.shape

torch.Size([8, 5, 1024])

In [47]:
F.softmax(p,dim=-1).shape

torch.Size([8, 5, 1024])

In [78]:
class LSTM(Module):
    def __init__(self, input_dim, n_hidden, n_layers, bidirectional=False, p=0.5):
        self.lstm = nn.LSTM(input_dim, n_hidden, n_layers, batch_first=True, bidirectional=bidirectional)
        self.drop = nn.Dropout(p)
        self.h = None

    def reset(self):
        self.h = None

    def forward(self, x):
        if (self.h is not None) and (x.shape[0] != self.h[0].shape[1]): #dealing with last batch on valid
#             self.h = [h_[:, :x.shape[0], :] for h_ in self.h]
            self.h = None
        raw, h = self.lstm(x, self.h)
        out = self.drop(raw)
        self.h = [h_.detach() for h_ in h]
        return out, h

In [96]:
class ConvLSTM(Module):
    def __init__(self, arch=resnet34, weights_file=None, num_classes=30, lstm_layers=1, hidden_dim=1024, 
                 bidirectional=True, attention=True, debug=False):
        model = Encoder(arch, 3, weights_file, head=False)
        nf = num_features_model(nn.Sequential(*model.body.children())) * 2
        self.encoder = TimeDistributed(model)
        self.lstm = LSTM(nf, hidden_dim, lstm_layers, bidirectional)
        self.attention = attention
        h_state = 2 * hidden_dim if bidirectional else hidden_dim
        self.attention_layer = nn.Linear(h_state,h_state)
        self.head = nn.Sequential(
            LinBnDrop( (lstm_layers if not attention else 1)*(2 * hidden_dim if bidirectional else hidden_dim), 
                      hidden_dim, p=0.2, act=nn.ReLU()),
            nn.Linear(hidden_dim, num_classes),
        )
        self.debug = debug
        
    def forward(self, x):
        batch_size, seq_length, c, h, w = x.shape
        x = self.encoder(x)
        if self.debug:  print(f' before lstm:   {x.shape}')
        if self.attention:
            attention_w = F.softmax(self.attention_layer(x).squeeze(-1), dim=-1)
            if self.debug: print(f' attention_w: {attention_w.shape}')
            x = attention_w * x
            if self.debug: print(f' after attention: {x.shape}')
        out, (h,c) = self.lstm(x)
        if self.debug:  print(f' after lstm:   {out.shape}')
        return self.head(out)
    
    def reset(self): self.lstm.reset()

In [97]:
model = ConvLSTM(attention=True,bidirectional=False,debug=True)

In [98]:
preds = model(x_t)

 before lstm:   torch.Size([8, 5, 1024])
 attention_w: torch.Size([8, 5, 1024])
 after attention: torch.Size([8, 5, 1024])
 after lstm:   torch.Size([8, 5, 1024])


RuntimeError: running_mean should contain 5 elements not 1024

In [91]:
lstm_layer = LSTM(1024,1024,1)

In [113]:
lstm_layer(torch.rand(8,5,1024))[1][1].shape

torch.Size([1, 8, 1024])

In [106]:
test_upsample = nn.Conv3d(1024,1,kernel_size=(1,3,3),padding=(0,1,1))

In [107]:
test_upsample(torch.rand(8,5,1024))

RuntimeError: Expected 5-dimensional input for 5-dimensional weight [1, 1024, 1, 3, 3], but got 3-dimensional input of size [8, 5, 1024] instead

In [108]:
nn.LSTM??

In [114]:
# credit: https://github.com/CommissarMa/CSRNet-pytorch/blob/master/model.py
from torchvision import models
import collections


class CSRNet(nn.Module):
    def __init__(self, load_weights=False):
        super(CSRNet, self).__init__()
        self.frontend_feat = [64, 64, 'M', 128, 128,
                              'M', 256, 256, 256, 'M', 512, 512, 512]
        self.backend_feat = [512, 512, 512, 256, 128, 64]
        self.frontend = make_layers(self.frontend_feat)
        self.backend = make_layers(
            self.backend_feat, in_channels=512, dilation=True)
        self.output_layer = nn.Conv2d(64, 1, kernel_size=1)
        if not load_weights:
            mod = models.vgg16(pretrained=True)
            self._initialize_weights()
            fsd = collections.OrderedDict()
            # 10 convlution *(weight, bias) = 20 parameters
            for i in range(len(self.frontend.state_dict().items())):
                temp_key = list(self.frontend.state_dict().items())[i][0]
                fsd[temp_key] = list(mod.state_dict().items())[i][1]
            self.frontend.load_state_dict(fsd)

    def forward(self, x):
        x = self.frontend(x)
        x = self.backend(x)
        x = self.output_layer(x)
        x = nn.functional.interpolate(x, scale_factor=8)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


def make_layers(cfg, in_channels=3, batch_norm=False, dilation=False):
    if dilation:
        d_rate = 2
    else:
        d_rate = 1
    layers = []
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3,
                               padding=d_rate, dilation=d_rate)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)

In [120]:
csrnet = TimeDistributed(CSRNet())

In [121]:
preds = csrnet(x_t)
preds.shape

torch.Size([8, 5, 1, 64, 64])