In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torch.optim as optim
import os

In [2]:
class spectral_attention(nn.Module):
    """Channel ("spectral") self-attention with an identity residual.

    For an input of shape (b, c, w, h, d) each of the c channel maps is
    flattened to a vector of length N = w*h*d. A (c x c) affinity matrix is
    built from their pairwise dot products and normalised with a softmax over
    its last dimension, so each row sums to 1. It is then used to re-mix the
    channels. The re-mixed tensor is added back to the input, so the output
    has the same shape as the input.
    """

    def __init__(self):
        super().__init__()
        # Row-wise normalisation of the (b, c, c) affinity. The previous dim=0
        # normalised across the *batch* axis, coupling unrelated samples (and
        # degenerating to all-ones weights for a batch of size 1).
        self.layer1 = nn.Softmax(dim=-1)

    def forward(self, x):
        b, c, w, h, d = x.shape
        # Flatten every position, including d (the old code dropped d and only
        # worked when d == 1).
        feats = x.reshape(b, c, w * h * d)                        # (b, c, N)
        affinity = torch.matmul(feats, feats.transpose(1, 2))     # (b, c, c)
        affinity = self.layer1(affinity)
        mixed = torch.matmul(affinity, feats)                     # (b, c, N)
        return x + mixed.view(b, c, w, h, d)

In [3]:
class spatial_attention(nn.Module):
    """Position ("spatial") self-attention with an identity residual.

    Three ConvBlocks (1x1x1 conv + BatchNorm + ReLU) produce query, key and
    value maps. For every pair of the N = w*h*de positions a dot-product score
    is computed and softmax-normalised over the key positions. The values are
    then aggregated with those weights. The result is added back to the input,
    so the output shape equals the input shape.

    Args:
        channels: number of input/output channels (default 24).
    """

    def __init__(self, channels=24):
        super().__init__()
        # NOTE: ConvBlock is defined in a later notebook cell; this only works
        # because the name is looked up when the module is instantiated.
        self.conv1 = ConvBlock(channels, channels, (1, 1, 1))
        self.conv2 = ConvBlock(channels, channels, (1, 1, 1))
        self.conv3 = ConvBlock(channels, channels, (1, 1, 1))

        # Normalise over the key positions (last dim). The previous dim=0
        # normalised across the batch axis instead.
        self.layer1 = nn.Softmax(dim=-1)

    def forward(self, x):
        ba, ch, w, h, de = x.shape
        n = w * h * de  # include de; the old code flattened only w*h

        query = self.conv1(x).reshape(ba, ch, n).permute(0, 2, 1)  # (ba, N, ch)
        key = self.conv2(x).reshape(ba, ch, n)                       # (ba, ch, N)
        value = self.conv3(x).reshape(ba, ch, n)                     # (ba, ch, N)

        attn = self.layer1(torch.matmul(query, key))                 # (ba, N, N)
        out = torch.matmul(value, attn.permute(0, 2, 1))             # (ba, ch, N)

        return out.view(ba, ch, w, h, de) + x

In [4]:
class ConvBlock(nn.Module):
    """3-D convolution, optionally followed by BatchNorm3d and then ReLU.

    Args:
        inp: input channel count for the Conv3d.
        out: output channel count (also the BatchNorm3d feature count).
        kernel: Conv3d kernel size.
        stride: Conv3d stride (default (1, 1, 1)).
        padding: Conv3d padding (default (0, 0, 0)).
        b: apply the BatchNorm3d after the convolution when True.
        rel: apply the ReLU (after BatchNorm, if any) when True.

    The BatchNorm3d submodule is always created (and therefore appears in the
    state_dict) even when ``b`` is False.
    """

    def __init__(self, inp, out, kernel, stride=(1, 1, 1), padding=(0,0,0), b=True, rel=True):
        super().__init__()

        # Construction order (conv, bn, relu) is kept as-is: it fixes the
        # order of parameter initialisation and of state_dict entries.
        self.conv = nn.Conv3d(inp, out, kernel, stride=stride, padding=padding)
        self.bn = nn.BatchNorm3d(out)
        self.relu = nn.ReLU()
        self.b, self.rel = b, rel

    def forward(self, x):
        y = self.conv(x)
        y = self.bn(y) if self.b else y
        return self.relu(y) if self.rel else y
        

In [5]:
class spectral_features(nn.Module):
    """Spectral branch: strided spectral conv, hierarchical group convs, attention.

    Args:
        bands: length of the spectral axis of the input (default 102). The
            final conv's kernel is derived from it so that the spectral axis
            collapses to length 1 (previously hard-coded as 49, which matches
            102 bands only).

    forward input: a 5-D tensor (batch, A, B, bands, 1). The last axis must
        have size 1 because, after the permute, it becomes the single input
        channel of conv1. A and B are presumably the spatial patch dims —
        confirm against the data loader.
    forward output: (batch, 24, 1, 1, 1).
    """

    def __init__(self, bands=102):
        super().__init__()

        # kernel 5, stride 2, no padding along the spectral axis
        self.conv1 = ConvBlock(1, 24, (1, 1, 5), stride=(1, 1, 2))
        reduced = (bands - 5) // 2 + 1  # spectral length after conv1 (49 for 102)

        # kernel 5 with padding 2 keeps the spectral length unchanged
        self.conv2 = ConvBlock(6, 6, (1, 1, 5), padding=(0, 0, 2))
        self.conv3 = ConvBlock(12, 12, (1, 1, 5), padding=(0, 0, 2))
        self.conv4 = ConvBlock(18, 12, (1, 1, 5), padding=(0, 0, 2))

        # kernel spans the whole remaining spectrum -> spectral length 1
        self.conv5 = ConvBlock(36, 24, (1, 1, reduced))

        self.att = spectral_attention()

        self.bn = nn.BatchNorm3d(24)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool3d((1, 1, 1))

    def forward(self, x):
        # (batch, A, B, bands, 1) -> (batch, 1, A, B, bands)
        x = x.permute(0, 4, 1, 2, 3).contiguous()
        x = self.conv1(x)  # (batch, 24, A, B, reduced)

        # Four groups of 6 channels; each processed group is concatenated
        # into the input of the next group's conv.
        g1, g2, g3, g4 = torch.split(x, 6, dim=1)
        g2 = self.conv2(g2)                              # 6  -> 6
        g3 = self.conv3(torch.cat((g2, g3), dim=1))      # 12 -> 12
        g4 = self.conv4(torch.cat((g3, g4), dim=1))      # 18 -> 12

        x = self.conv5(torch.cat((g1, g2, g3, g4), dim=1))  # 36 -> 24, spectrum -> 1
        x = self.att(x)

        x = self.relu(self.bn(x))
        return self.pool(x)
        

In [6]:
class spatial_features(nn.Module):
    """Spatial branch: collapse the spectrum, then hierarchical 3x3 convs + attention.

    Args:
        bands: length of the spectral axis of the input (default 102). conv1
            uses a kernel spanning all bands, collapsing that axis to size 1.

    forward input: a 5-D tensor (batch, A, B, bands, 1). The last axis must
        have size 1 because, after the permute, it becomes the single input
        channel of conv1. A and B are presumably the spatial patch dims —
        confirm against the data loader.
    forward output: (batch, 24, 1, 1, 1).
    """

    def __init__(self, bands=102):
        super().__init__()

        # kernel covers the whole spectrum -> spectral length 1
        self.conv1 = ConvBlock(1, 24, (1, 1, bands), stride=(1, 1, 1))

        # 3x3 spatial kernels with padding 1 keep A x B unchanged
        self.conv2 = ConvBlock(6, 6, (3, 3, 1), padding=(1, 1, 0))
        self.conv3 = ConvBlock(12, 12, (3, 3, 1), padding=(1, 1, 0))
        self.conv4 = ConvBlock(18, 12, (3, 3, 1), padding=(1, 1, 0))

        # 1x1x1 fusion conv without padding. The previous padding=(1, 1, 0)
        # only grew the map by a zero-input border (conv output = bias there),
        # and that feature-free border then entered attention and pooling.
        self.conv5 = ConvBlock(36, 24, (1, 1, 1))

        self.att = spatial_attention()

        self.bn = nn.BatchNorm3d(24)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool3d((1, 1, 1))

    def forward(self, x):
        # (batch, A, B, bands, 1) -> (batch, 1, A, B, bands)
        x = x.permute(0, 4, 1, 2, 3).contiguous()
        x = self.conv1(x)  # (batch, 24, A, B, 1)

        # Four groups of 6 channels; each processed group is concatenated
        # into the input of the next group's conv.
        g1, g2, g3, g4 = torch.split(x, 6, dim=1)
        g2 = self.conv2(g2)                              # 6  -> 6
        g3 = self.conv3(torch.cat((g2, g3), dim=1))      # 12 -> 12
        g4 = self.conv4(torch.cat((g3, g4), dim=1))      # 18 -> 12

        x = self.conv5(torch.cat((g1, g2, g3, g4), dim=1))  # 36 -> 24
        x = self.att(x)

        x = self.relu(self.bn(x))
        return self.pool(x)
        

In [7]:
class HResNet(nn.Module):
    """Two-branch classifier fusing spectral and spatial features.

    Both branches receive the same input tensor. Each returns a pooled
    (batch, 24, 1, 1, 1) descriptor. The two descriptors are concatenated
    into 48 features and passed through two Linear layers and a softmax.

    Args:
        num_classes: number of output classes (default 9).

    forward returns: (batch, num_classes) class *probabilities*; the softmax
        is already applied.
    """

    def __init__(self, num_classes=9):
        super().__init__()

        self.spectral = spectral_features()
        self.spatial = spatial_features()
        # NOTE(review): there is no nonlinearity between `fusion` and `out`,
        # so the two layers compose to a single affine map — confirm intended.
        self.fusion = nn.Linear(48, 48)
        self.out = nn.Linear(48, num_classes)
        # NOTE(review): if training uses nn.CrossEntropyLoss (which applies
        # log-softmax itself) on this output, softmax is applied twice —
        # confirm against the training loop.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        spec = self.spectral(x)
        spat = self.spatial(x)

        fused = torch.cat((spec, spat), dim=1)     # (batch, 48, 1, 1, 1)
        # Flatten every non-batch dim. The old view() omitted the 5th dim and
        # only worked because its size is 1.
        fused = torch.flatten(fused, start_dim=1)  # (batch, 48)

        logits = self.out(self.fusion(fused))
        return self.softmax(logits)