In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [4]:
def conv3x3x3(in_planes, out_planes, stride=1, groups = 1):
    """Build a bias-free 3x3x3 ``nn.Conv3d`` with padding 1.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (int or 3-tuple).
        groups: number of channel groups; must divide both channel counts.

    Returns:
        The configured ``nn.Conv3d`` module.
    """
    # Forward the caller's ``groups`` instead of silently forcing it to 1.
    return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride,
                     groups=groups, padding=1, bias=False)

In [5]:
def conv1x1x1(in_planes, out_planes, stride=1, groups = 1):
    """Build a bias-free 1x1x1 (pointwise) ``nn.Conv3d`` without padding.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (int or 3-tuple).
        groups: number of channel groups; must divide both channel counts.

    Returns:
        The configured ``nn.Conv3d`` module.
    """
    # Forward the caller's ``groups`` instead of silently forcing it to 1.
    return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride,
                     groups=groups, padding=0, bias=False)

### Non-local

In [6]:
class NonLocal(nn.Module):
    """Non-local block over 5-D ``(b, c, t, h, w)`` feature maps.

    Three 1x1x1 convolutions (``theta``, ``phi``, ``g``) embed the input into
    ``in_channel // 2`` channels. Every position attends to every other
    position through ``softmax(theta^T phi)``; the attended ``g`` features are
    projected back to ``in_channel`` channels by ``W`` and added to the input.

    Args:
        in_channel: number of channels of the input feature map.
    """

    def __init__(self, in_channel):
        super(NonLocal,self).__init__()
        self.in_channels = in_channel
        self.inter_channels = in_channel // 2
        # The norm must match the conv's output width (inter_channels).
        self.g = nn.Sequential(nn.Conv3d(in_channels=self.in_channels, out_channels=self.inter_channels,
                                         kernel_size=1, stride=1, padding=0),
                               nn.BatchNorm3d(self.inter_channels))

        self.theta = nn.Conv3d(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)
        self.phi = nn.Conv3d(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)

        self.W = nn.Sequential(nn.Conv3d(in_channels=self.inter_channels, out_channels=self.in_channels,
                                         kernel_size=1, stride=1, padding=0),
                               nn.BatchNorm3d(self.in_channels))

    def forward(self, x):
        '''
        param x: (b, c, t, h, w)
        return: tensor with the same shape as ``x`` (residual output)
        '''
        b, _, t, h, w = x.shape
        # Flatten the t*h*w positions: theta/g -> (b, thw, c'), phi -> (b, c', thw).
        theta_x = self.theta(x).reshape(b, self.inter_channels, -1).permute(0, 2, 1)
        phi_x = self.phi(x).reshape(b, self.inter_channels, -1)
        g_x = self.g(x).reshape(b, self.inter_channels, -1).permute(0, 2, 1)

        # Pairwise affinities, normalised over the attended positions.
        attention = F.softmax(torch.matmul(theta_x, phi_x), dim=-1)

        # Aggregate g with the attention weights and restore the 5-D layout.
        y = torch.matmul(attention, g_x).permute(0, 2, 1).contiguous()
        y = y.view(b, self.inter_channels, t, h, w)

        return self.W(y) + x

### SlowFast

In [16]:
class FastBottleneck(nn.Module):
    """Residual bottleneck with a temporal (3,1,1) first convolution.

    Layout: (3,1,1) conv -> (1,3,3) conv (optionally strided) -> 1x1x1 conv,
    each followed by batch norm, with a ReLU after the first two and after the
    residual sum.

    ``inChannel`` is the number of channels entering the block. The inner
    width is ``inChannel // reduce`` and the output has ``4 * width`` channels:

    * ``first=True``     -> reduce 1, output ``4 * inChannel``
    * ``instack=False``  -> reduce 2, output ``2 * inChannel``
    * otherwise          -> reduce 4, output ``inChannel``

    Args:
        inChannel: channels of the input tensor.
        stride: stride of the (1,3,3) convolution and of the shortcut projection.
        isDownSample: project the shortcut with a strided 1x1x1 conv + norm.
        instack: ``False`` for a stage-opening block that doubles the width.
        first: ``True`` for a block that quadruples the width; takes
            precedence over ``instack``.

    Raises:
        ValueError: if ``inChannel`` is not divisible by the reduction factor.
    """

    def __init__(self, inChannel, stride = 1, isDownSample=False, instack=True, first=False):
        super(FastBottleneck, self).__init__()
        self.expansion = 4
        if first:
            reduce = 1
        elif not instack:
            reduce = 2
        else:
            reduce = 4
        if inChannel % reduce != 0:
            raise ValueError(
                "inChannel (%d) must be divisible by %d" % (inChannel, reduce))

        # Integer division keeps channel counts as ints (Conv3d rejects floats).
        width = inChannel // reduce
        outChannel = width * self.expansion

        # The shortcut has to be projected whenever the channel count changes,
        # otherwise the residual sum fails on a shape mismatch.
        self.isDownSample = isDownSample or inChannel != outChannel

        self.conv1 = nn.Conv3d(inChannel, width, kernel_size=(3,1,1), padding=(1,0,0))
        self.norm1 = nn.BatchNorm3d(width)

        self.conv2 = nn.Conv3d(width, width, kernel_size=(1,3,3), stride=stride, padding=(0,1,1))
        self.norm2 = nn.BatchNorm3d(width)

        self.conv3 = nn.Conv3d(width, outChannel, kernel_size=1)
        self.norm3 = nn.BatchNorm3d(outChannel)
        self.relu1 = nn.ReLU(inplace=True)

        if self.isDownSample:
            # The projection reads the block input, so it takes inChannel channels.
            self.downconv = nn.Conv3d(inChannel, outChannel, kernel_size=1, stride=stride)
            self.downnorm = nn.BatchNorm3d(outChannel)

    def forward(self, x):
        """Apply the bottleneck to ``x`` of shape (b, inChannel, t, h, w)."""
        identity = x

        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.norm2(out)
        out = self.relu1(out)

        out = self.conv3(out)
        out = self.norm3(out)

        if self.isDownSample:
            identity = self.downconv(identity)
            identity = self.downnorm(identity)
        out = out + identity
        out = self.relu1(out)

        return out

In [17]:
class SlowBottleneck(nn.Module):
    """Residual bottleneck without temporal convolution (1x1x1 first conv).

    Layout: 1x1x1 conv -> (1,3,3) conv (optionally strided) -> 1x1x1 conv,
    each followed by batch norm, with a ReLU after the first two and after the
    residual sum.

    ``inChannel`` is the number of channels entering the block. The inner
    width is ``inChannel // reduce`` and the output has ``4 * width`` channels:

    * ``first=True``     -> reduce 1, output ``4 * inChannel``
    * ``instack=False``  -> reduce 2, output ``2 * inChannel``
    * otherwise          -> reduce 4, output ``inChannel``

    Args:
        inChannel: channels of the input tensor.
        stride: stride of the (1,3,3) convolution and of the shortcut projection.
        isDownSample: project the shortcut with a strided 1x1x1 conv + norm.
        instack: ``False`` for a stage-opening block that doubles the width.
        first: ``True`` for a block that quadruples the width; takes
            precedence over ``instack``.

    Raises:
        ValueError: if ``inChannel`` is not divisible by the reduction factor.
    """

    def __init__(self, inChannel, stride = 1, isDownSample=False, instack=True, first=False):
        super(SlowBottleneck, self).__init__()
        self.expansion = 4
        self.alpha = 8
        self._beta = 1 / 8
        if first:
            reduce = 1
        elif not instack:
            reduce = 2
        else:
            reduce = 4
        if inChannel % reduce != 0:
            raise ValueError(
                "inChannel (%d) must be divisible by %d" % (inChannel, reduce))

        # Integer division keeps channel counts as ints (Conv3d rejects floats).
        width = inChannel // reduce
        outChannel = width * self.expansion

        # The shortcut has to be projected whenever the channel count changes,
        # otherwise the residual sum fails on a shape mismatch.
        self.isDownSample = isDownSample or inChannel != outChannel

        self.conv1 = nn.Conv3d(inChannel, width, kernel_size=1)
        self.norm1 = nn.BatchNorm3d(width)

        self.conv2 = nn.Conv3d(width, width, kernel_size=(1,3,3), stride=stride, padding=(0,1,1))
        self.norm2 = nn.BatchNorm3d(width)

        self.conv3 = nn.Conv3d(width, outChannel, kernel_size=1)
        self.norm3 = nn.BatchNorm3d(outChannel)
        self.relu1 = nn.ReLU(inplace=True)

        if self.isDownSample:
            # The projection reads the block input, so it takes inChannel channels.
            self.downconv = nn.Conv3d(inChannel, outChannel, kernel_size=1, stride=stride)
            self.downnorm = nn.BatchNorm3d(outChannel)

    def forward(self, x):
        """Apply the bottleneck to ``x`` of shape (b, inChannel, t, h, w)."""
        identity = x

        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.norm2(out)
        out = self.relu1(out)

        out = self.conv3(out)
        out = self.norm3(out)

        if self.isDownSample:
            identity = self.downconv(identity)
            identity = self.downnorm(identity)
        out = out + identity
        out = self.relu1(out)

        return out

In [19]:
class SlowFast(nn.Module):
    """Two-pathway (fast + slow) 3-D backbone with lateral connections.

    Both pathways start with a stem (strided conv + spatial max-pool) followed
    by ``len(li)`` stages of bottleneck blocks. Each stage is laid out as:

    1. one channel-widening block (x4 in the first stage, x2 afterwards),
    2. ``blockLen - 2`` width-preserving blocks,
    3. one width-preserving block with spatial stride (1, 2, 2).

    A stage therefore always holds at least two blocks. The fast pathway uses
    ``FastBottleneck`` everywhere; the slow pathway uses ``SlowBottleneck`` in
    its first two stages and ``FastBottleneck`` (temporal convolutions) from
    the third stage on.

    After the stem and after every stage except the last, the fast features go
    through a time-strided (5,1,1) convolution and are added to the slow
    features.

    NOTE(review): the lateral convolutions project to the slow pathway's
    channel count so that the element-wise sum is well-defined; confirm that
    summation (rather than concatenation) is the intended fusion.

    Args:
        li: number of blocks per stage, e.g. ``[3, 4, 6, 3]``.
        initial: ``[fast_stem_channels, slow_stem_channels]``, e.g. ``[8, 64]``.
    """

    def __init__(self, li, initial): #initial=[8,64]
        super(SlowFast, self).__init__()
        self.expansion = 2
        self.alpha = 8       # temporal stride of the lateral convolutions
        self._beta = 1 / 8   # kept for reference; not used by the layers below
        self.fastInitialC = initial[0]
        self.slowInitialC = initial[1]

        fast_stem = nn.Sequential(
            nn.Conv3d(3, self.fastInitialC, (5,7,7), (1,2,2), padding=(2,3,3)),
            nn.MaxPool3d((1,3,3), stride=(1,2,2), padding=(0,1,1)),
        )
        slow_stem = nn.Sequential(
            nn.Conv3d(3, self.slowInitialC, (1,7,7), (1,2,2), padding=(0,3,3)),
            nn.MaxPool3d((1,3,3), stride=(1,2,2), padding=(0,1,1)),
        )

        # fastC[i] / slowC[i]: channels entering stage i, which is also the
        # channel count at the point where lateral connection i is applied.
        self.fastC = []
        self.slowC = []
        fast_parts = [fast_stem]
        slow_parts = [slow_stem]

        fast_channels = self.fastInitialC
        slow_channels = self.slowInitialC
        for idx, blockLen in enumerate(li):
            self.fastC.append(fast_channels)
            self.slowC.append(slow_channels)

            fast_stage, fast_channels = self._make_stage(
                FastBottleneck, fast_channels, blockLen, first=(idx == 0))
            # The slow pathway only gets temporal convolutions in later stages.
            slow_block = SlowBottleneck if idx < 2 else FastBottleneck
            slow_stage, slow_channels = self._make_stage(
                slow_block, slow_channels, blockLen, first=(idx == 0))

            fast_parts.append(fast_stage)
            slow_parts.append(slow_stage)

        # ModuleList registers the sub-modules so their parameters are tracked.
        self.fastnet = nn.ModuleList(fast_parts)
        self.slownet = nn.ModuleList(slow_parts)
        self.lateral = nn.ModuleList([
            nn.Conv3d(fast_c, slow_c, kernel_size=(5,1,1),
                      stride=(self.alpha, 1, 1), padding=(2,0,0))
            for fast_c, slow_c in zip(self.fastC, self.slowC)
        ])

    @staticmethod
    def _make_stage(block, in_channels, blockLen, first):
        """Build one stage; return ``(nn.Sequential, output_channels)``."""
        if first:
            out_channels = in_channels * 4
            opener = block(in_channels, isDownSample=True, first=True)
        else:
            out_channels = in_channels * 2
            opener = block(in_channels, isDownSample=True, instack=False)
        blocks = [opener]
        blocks.extend(block(out_channels) for _ in range(blockLen - 2))
        blocks.append(block(out_channels, stride=(1, 2, 2), isDownSample=True))
        return nn.Sequential(*blocks), out_channels

    def forward(self, fast_data, slow_data):
        """Run both pathways and fuse fast features into the slow pathway.

        Args:
            fast_data: (b, 3, t_fast, h, w) clip for the fast pathway.
            slow_data: (b, 3, t_slow, h, w) clip for the slow pathway. For the
                lateral sum to line up, ``t_fast`` must equal
                ``alpha * t_slow`` and both clips must share ``h`` and ``w``.

        Returns:
            Tuple ``(fast_features, slow_features)`` from the last stage.
        """
        last = len(self.fastnet) - 1
        for i, (fast_part, slow_part) in enumerate(zip(self.fastnet, self.slownet)):
            fast_data = fast_part(fast_data)
            slow_data = slow_part(slow_data)
            if i != last:
                slow_data = slow_data + self.lateral[i](fast_data)
        return fast_data, slow_data
        

### Resnet

In [None]:
class BasicBlock(nn.Module):
    """Two-layer 3x3x3 residual block.

    Args:
        inplanes: channels of the input tensor.
        planes: channels produced by both convolutions.
        stride: stride of the first convolution (and of the shortcut projection).
        downsample: when True, the shortcut is projected with a strided
            1x1x1 convolution + batch norm so it matches the main branch.
            Required whenever ``inplanes != planes`` or ``stride != 1``.
    """

    def __init__(self, inplanes, planes, stride=1, downsample=False):
        super(BasicBlock, self).__init__()
        self.isDownSample = downsample
        self.conv1 = conv3x3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3x3(planes, planes)
        self.bn2 = nn.BatchNorm3d(planes)

        if self.isDownSample:
            self.downconv = conv1x1x1(inplanes, planes, stride)
            self.downnorm = nn.BatchNorm3d(planes)

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.isDownSample:
            residual = self.downconv(residual)
            residual = self.downnorm(residual)

        out = out + residual
        out = self.relu(out)

        return out

In [14]:
class resnet10(nn.Module):
    """Small 3-D ResNet classifier: stem, four ``BasicBlock`` stages, linear head.

    Args:
        num_classes: size of the output logit vector.
    """

    def __init__(self, num_classes):
        super(resnet10, self).__init__()
        self.conv1 = nn.Conv3d(3, 8, kernel_size=7,stride=(1, 2, 2),padding=(3, 3, 3),bias=False)
        self.bn1 = nn.BatchNorm3d(8)
        self.relu1 = nn.ReLU(inplace=True)
        self.maxpool1 = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)

        self.block2 = BasicBlock(8, 16, stride=2, downsample=True)
        self.block3 = BasicBlock(16, 32, stride=2, downsample=True)
        self.block4 = BasicBlock(32, 64, stride=2, downsample=True)
        self.block5 = BasicBlock(64, 128, stride=2, downsample=True)

        # Global pooling removes the dependency on the input clip size.
        self.avgpool = nn.AdaptiveAvgPool3d(1)
        # block5 emits 128 channels, so the classifier reads 128 features.
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map a (b, 3, t, h, w) clip to (b, num_classes) logits."""
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.maxpool1(out)

        out = self.block2(out)
        out = self.block3(out)
        out = self.block4(out)
        out = self.block5(out)

        out = self.avgpool(out)
        out = torch.flatten(out, 1)  # (b, 128, 1, 1, 1) -> (b, 128)
        return self.fc(out)

In [1]:
class decider:
    """Empty placeholder class; no behaviour is defined yet."""

In [2]:
class classifier:
    """Empty placeholder class; no behaviour is defined yet."""

In [None]:
# NOTE(review): this placeholder reuses the name of the SlowFast network class
# defined earlier in this notebook; running this cell rebinds that name.
# Rename or delete it before relying on the network class.
class SlowFast():
    """Unfinished placeholder; no behaviour is defined yet."""

    def __init__(self, ):
        """Placeholder constructor (the body was previously missing)."""
        pass

In [18]:
# Scratch check: enumerate yields (index, element) pairs starting at index 0.
for idx, x in enumerate((1, 2, 3)):
    print(f"{idx} {x}")

0 1
1 2
2 3
