In [None]:
import numpy as np 
import torch 
import torch.nn as nn 
import torchvision 
import matplotlib.pyplot as plt

In [None]:
def conv_block(in_f,out1,out2,out3):
    """Build a three-stage convolutional block that ends with max-pooling.

    Every stage is ``Conv3d -> BatchNorm3d -> ELU``. The convolutions use a
    (1, 3, 3) kernel with stride 1 and (0, 1, 1) padding, so the size of every
    dimension is preserved. The final ``MaxPool3d`` uses a (1, 2, 2) kernel and
    stride, halving the last two dimensions only.

    Args:
        in_f: number of channels entering the first convolution.
        out1: number of channels produced by the first stage.
        out2: number of channels produced by the second stage.
        out3: number of channels produced by the third stage.

    Returns:
        An ``nn.Sequential`` holding ten layers (three conv/norm/activation
        triples followed by the pooling layer), indexed 0..9.
    """
    widths = (in_f, out1, out2, out3)
    layers = []
    # Walk over consecutive (input, output) channel pairs, one per stage.
    for channels_in, channels_out in zip(widths[:-1], widths[1:]):
        layers.append(
            nn.Conv3d(
                in_channels=channels_in,
                out_channels=channels_out,
                kernel_size=(1, 3, 3),
                stride=1,
                padding=(0, 1, 1),
            )
        )
        layers.append(nn.BatchNorm3d(num_features=channels_out))
        layers.append(nn.ELU())
    layers.append(nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)))
    return nn.Sequential(*layers)
    
class CNN_part(nn.Module):
    """Convolutional part of the network: produces a depth map and a feature map.

    The input goes through a single conv/norm/ELU stem and then three
    ``conv_block`` stages. The output of every stage is resized to a common
    (5, 32, 32) grid, the three resized tensors are concatenated along the
    channel dimension (3 x 128 = 384 channels) and fed to two independent
    stacks of three convolutions, each ending in a single output channel.

    Attribute names are part of the saved-model format (checkpoints are
    pickled whole), so they are kept as they are.
    """

    def __init__(self):
        super(CNN_part, self).__init__()

        # Resizes any block output to a fixed (5, 32, 32) grid using
        # nearest-neighbour interpolation (the "64" in the name is historical).
        self.resize_64 = nn.Upsample(size=(5,32,32), mode='nearest')

        # Stem: single conv layer, 3 -> 64 channels.
        self.con0 = nn.Conv3d(in_channels=3,out_channels=64,kernel_size=(1,3,3), stride=1,padding=(0,1,1))
        self.bn0 = nn.BatchNorm3d(num_features=64)
        self.act0 = nn.ELU()

        # Three convolutional blocks; each one ends with 128 channels and
        # halves the last two dimensions.
        self.block1 = conv_block(64,128,196,128)
        self.block2 = conv_block(128,128,196,128)
        self.block3 = conv_block(128,128,196,128)

        # Feature-map head: 384 -> 128 -> 3 -> 1 channels.
        self.con1 = nn.Conv3d(in_channels=384, out_channels=128, kernel_size=(1,3,3), stride=1,padding=(0,1,1))
        self.con2 = nn.Conv3d(in_channels=128, out_channels=3, kernel_size=(1,3,3), stride=1,padding=(0,1,1))
        self.con3 = nn.Conv3d(in_channels=3, out_channels=1, kernel_size=(1,3,3), stride=1,padding=(0,1,1))
        # Depth-map head: 384 -> 128 -> 64 -> 1 channels.
        self.con4 = nn.Conv3d(in_channels=384, out_channels=128, kernel_size=(1,3,3), stride=1,padding=(0,1,1))
        self.con5 = nn.Conv3d(in_channels=128, out_channels=64, kernel_size=(1,3,3), stride=1,padding=(0,1,1))
        self.con6 = nn.Conv3d(in_channels=64, out_channels=1, kernel_size=(1,3,3), stride=1,padding=(0,1,1))

    def forward(self,x):
        """Compute the depth map and the feature map for ``x``.

        Args:
            x: 5-D input tensor with 3 channels in dimension 1 (as required
                by ``con0``).

        Returns:
            Tuple ``(D3, F3)``: the depth-head output and the feature-head
            output, in that order. Both have one channel on the (5, 32, 32)
            grid set by ``resize_64``.
        """
        stem = self.act0(self.bn0(self.con0(x)))

        # Run the blocks in sequence and bring every intermediate output to
        # the common grid so they can be concatenated.
        stage1 = self.block1(stem)
        stage2 = self.block2(stage1)
        stage3 = self.block3(stage2)
        resized = [self.resize_64(stage) for stage in (stage1, stage2, stage3)]

        inp = torch.cat(resized, 1)

        # Feature-map head.
        F3 = self.con3(self.con2(self.con1(inp)))

        # Depth-map head.
        D3 = self.con6(self.con5(self.con4(inp)))

        return D3, F3

In [None]:
class RNN_part(nn.Module):
    """Recurrent part of the network: LSTM over 5 frames, then FC, then FFT.

    The input is flattened to a single sequence of 5 steps with 32*32
    features each, passed through a one-layer LSTM (hidden size 100), the
    5 x 100 outputs are flattened and mapped to a 100-sample signal by a
    linear layer, and the one-sided FFT of that signal is returned.

    The LSTM state is kept in ``self.hidden`` and carried over from one
    ``forward`` call to the next.
    """

    def __init__(self):
        super(RNN_part, self).__init__()

        # Initial (h_0, c_0), each of shape (num_layers=1, batch=1, hidden=100).
        # Created on the CPU; forward() moves it to the device of its input so
        # the module works with or without a GPU.
        self.hidden = (torch.zeros(1, 1, 100), torch.zeros(1, 1, 100))
        self.LSTM = nn.LSTM(input_size=32*32, hidden_size=100, num_layers=1,batch_first=True)
        self.fc = nn.Linear(in_features=500,out_features=100)

    def forward(self, F):
        """Run the recurrent head on one clip.

        Args:
            F: tensor holding exactly 5 * 32 * 32 elements (any shape that can
                be reshaped to ``(1, 5, 1024)``).

        Returns:
            Tensor of shape ``(51, 2)``: real and imaginary parts of the
            one-sided FFT of the 100-sample output of the linear layer.
        """
        sequence = F.reshape(1, 5, -1)
        state = tuple(h.to(sequence.device) for h in self.hidden)
        out, new_state = self.LSTM(sequence, state)
        # Keep the state values for the next call but cut the autograd history:
        # otherwise every call would extend one ever-growing graph and a later
        # backward pass would have to traverse all earlier clips.
        self.hidden = tuple(h.detach() for h in new_state)
        R = self.fc(out.reshape(-1))
        return self._rfft_real_imag(R)

    @staticmethod
    def _rfft_real_imag(signal):
        """One-sided FFT of a 1-D real signal as a (..., 2) real/imag tensor."""
        fft_module = getattr(torch, "fft", None)
        if hasattr(fft_module, "rfft"):
            # Modern API (torch.rfft was removed in PyTorch 1.8). view_as_real
            # restores the trailing real/imag dimension of the legacy call.
            return torch.view_as_real(fft_module.rfft(signal))
        # Legacy PyTorch where torch.fft is still a function.
        return torch.rfft(signal, signal_ndim=1, normalized=False)

In [None]:
class Anti_spoof_net(nn.Module):
    """Full network: CNN part followed by a depth-based mask and the RNN part.

    The CNN produces a depth map ``D`` and a feature map ``T``. The feature
    map is zeroed wherever the depth map is below ``self.threshold`` and the
    masked result is passed to the RNN part.
    """

    def __init__(self):
        super(Anti_spoof_net, self).__init__()

        # Depth values below this threshold mask out the feature map.
        self.threshold = 0.1
        self.CNN = CNN_part()
        self.RNN = RNN_part()

    def forward(self, x):
        """Run the whole network.

        Args:
            x: input tensor accepted by ``self.CNN``.

        Returns:
            Tuple ``(D, R)``: the depth map from the CNN part and the output
            of the RNN part computed on the masked feature map.
        """
        D, T = self.CNN(x)
        # Non-rigid registration layer (simplified): binary mask that is 1
        # where the depth map reaches the threshold and 0 elsewhere. Deriving
        # it from D keeps it on D's device and at D's shape, whatever they are.
        V = (D >= self.threshold).to(T.dtype)
        U = T * V

        # The anchor-based alignment step (`turning(U, anchors)`) is not
        # implemented; the masked feature map goes to the RNN directly.
        R = self.RNN(U)

        return D, R

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Resume from a previously saved model (the whole module was pickled with
# torch.save). map_location places its tensors on `device`, so a checkpoint
# written on a GPU machine can also be opened on a CPU-only one.
# To start from scratch instead: model = Anti_spoof_net().to(device)
# NOTE(review): torch.load unpickles arbitrary objects - only load trusted files.
model = torch.load("/kaggle/input/modell11/model", map_location=device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(),lr=3e-3,betas=(0.9, 0.999),eps=1e-08)

In [None]:
def CNN_part_train(our_net,device, optimizer, load_train, anchor_train, criterion, N = 1,
                   start_batch=780, checkpoint_batch=704, checkpoint_path='model'):
    """Train the network on the depth-map loss, using the first frame only.

    For every batch with index >= ``start_batch`` the first frame of the
    images (``images[:, :, 0:1]``) is fed to the network and the predicted
    depth map is compared with the first frame of the target depth map.
    Batches whose prediction contains NaN are skipped without an optimizer
    step.

    Args:
        our_net: callable returning ``(depth_map, other)`` for an image batch.
        device: device the image and depth batches are moved to.
        optimizer: optimizer updating the parameters of ``our_net``.
        load_train: iterable of ``(images, depth, third)`` batches; the third
            item is not used here.
        anchor_train: unused; kept for signature compatibility.
        criterion: loss called as ``criterion(target_depth, predicted_depth)``.
        N: number of epochs.
        start_batch: batches with a smaller index are skipped (used to resume
            in the middle of an epoch).
        checkpoint_batch: batch index at which ``our_net`` is saved before the
            optimizer step. With the defaults this index lies before
            ``start_batch``, so no checkpoint is written.
        checkpoint_path: file the checkpoint is written to.

    Note:
        Progress output reads the notebook-level global ``labels_train``,
        indexed by batch number.
    """
    for epoch in range(N):
        # Both counters are per epoch so the printed running average stays
        # meaningful from the second epoch on.
        running_loss = 0.0
        total = 0
        for i, (img, d, _) in enumerate(load_train, 0):
            if i < start_batch:
                continue
            # Only batches that are actually used are moved to the device.
            images, depth = img.to(device), d.to(device)

            # training step
            optimizer.zero_grad()
            net_depth, _ = our_net(images[:,:,0:1,:,:])
            # handle NaN: skip the update for this batch
            if torch.isnan(net_depth).any():
                continue
            if i == checkpoint_batch:
                torch.save(our_net, checkpoint_path)
            loss = criterion(depth[:,:,0:1,:,:], net_depth)
            # The depth loss only depends on this batch's forward pass, so the
            # graph does not need to be retained after backward.
            loss.backward()
            optimizer.step()
            # compute statistics
            print(depth.size(0))
            total += depth.size(0)
            running_loss += loss.item()
            print(labels_train[i])
            print(loss.item())
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / total))

        print('Epoch finished')
    print('Finished Training')

In [None]:
def train_RNN(our_net,device, optimizer, load_train, anchor_train, criterion, N = 1,
              start_batch=0, checkpoint_batch=13, checkpoint_path='model'):
    """Train the network on the rPPG loss computed from its second output.

    For every batch the network's second output ``F`` is compared with the
    target signal: ``criterion(F[:-1, 0], rppg[0, :])``, i.e. column 0 of all
    rows of ``F`` except the last one against the first row of the target.
    Batches whose output contains NaN are skipped without an optimizer step.

    Args:
        our_net: callable returning ``(other, F)`` for an image batch.
        device: device the image and rPPG batches are moved to.
        optimizer: optimizer updating the parameters of ``our_net``.
        load_train: iterable of ``(images, depth, rppg)`` batches; the depth
            item is only used for its batch size.
        anchor_train: unused; kept for signature compatibility.
        criterion: loss called as ``criterion(prediction, target)``.
        N: number of epochs.
        start_batch: batches with a smaller index are skipped.
        checkpoint_batch: batch index at which ``our_net`` is saved before the
            optimizer step.
        checkpoint_path: file the checkpoint is written to.

    Note:
        Progress output reads the notebook-level global ``labels_train``,
        indexed by batch number.
    """
    for epoch in range(N):
        # Both counters are per epoch so the printed running average stays
        # meaningful from the second epoch on.
        running_loss = 0.0
        total = 0
        for i, (img, d, r) in enumerate(load_train, 0):
            if i < start_batch:
                continue
            images, rppg = img.to(device), r.to(device)

            # training step
            optimizer.zero_grad()
            _, spectrum = our_net(images)
            # handle NaN: skip the update for this batch
            if torch.isnan(spectrum).any():
                continue
            if i == checkpoint_batch:
                torch.save(our_net, checkpoint_path)
            loss = criterion(spectrum[:-1,0], rppg[0,:])
            # retain_graph=True is kept on purpose: if our_net carries
            # recurrent state (and with it graph history) from one forward
            # pass to the next, freeing the graph here would make the next
            # backward pass fail.
            loss.backward(retain_graph=True)
            optimizer.step()
            # compute statistics
            print(labels_train[i])
            print(loss.item())

            total += d.size(0)
            running_loss += loss.item()
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / total))
        print('Epoch finished')
    print('Finished Training')

In [None]:
def train_All(net,device, optimizer, load_train, anchor_train, criterion, N = 10):
    """Alternate CNN (depth) and RNN (rPPG) training for ``N`` rounds.

    Each round runs ``CNN_part_train`` and then ``train_RNN`` with their
    default settings and saves the whole network after each phase
    ('mmodel' after the CNN phase, 'modelll' after the RNN phase).

    Args:
        net: network to train; passed on to both training functions.
        device: device the batches are moved to.
        optimizer: optimizer updating the parameters of ``net``.
        load_train: iterable of training batches.
        anchor_train: passed through to the training functions.
        criterion: loss function shared by both phases.
        N: number of CNN/RNN rounds.

    Returns:
        None.
    """
    for _ in range(N):
        CNN_part_train(net,device, optimizer, load_train, anchor_train, criterion)
        torch.save(net,'mmodel')
        train_RNN(net,device, optimizer, load_train, anchor_train, criterion)
        torch.save(net,'modelll')
    
# Run the alternating CNN/RNN training schedule on the loaded model.
# NOTE(review): `load_train` and `anchor_train` are not defined in the cells
# shown here - they must come from an earlier data-loading cell.
outputs = train_All(
    net=model,
    device=device,
    optimizer=optimizer,
    load_train=load_train,
    anchor_train=anchor_train,
    criterion=criterion,
)

# Reload the checkpoint file 'model' from the working directory (the file name
# the training functions save to).
model = torch.load('model')