In [87]:
import PIL
import time, json
import torch
import torchvision
import torch.nn.functional as F
from einops import rearrange
from torch import nn
import torch.nn.init as init
from einops import rearrange, repeat
import collections
import torch.nn as nn


def _weights_init(m):
    classname = m.__class__.__name__
    #print(classname)
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv3d):
        init.kaiming_normal_(m.weight)

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

# 等于 PreNorm
class LayerNormalize(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


# 等于 FeedForward
class MLP_Block(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):

    def __init__(self, dim, heads, dim_heads, dropout):
        super().__init__()
        inner_dim = dim_heads * heads
        self.heads = heads
        self.scale = dim_heads ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x, mask = None):
        # x:[b,n,dim]
        b, n, _, h = *x.shape, self.heads

        # get qkv tuple:([b,n,head_num*head_dim],[...],[...])
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        # split q,k,v from [b,n,head_num*head_dim] -> [b,head_num,n,head_dim]
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        # transpose(k) * q / sqrt(head_dim) -> [b,head_num,n,n]
        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        mask_value = -torch.finfo(dots.dtype).max

        # mask value: -inf
        if mask is not None:
            mask = F.pad(mask.flatten(1), (1, 0), value = True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            mask = mask[:, None, :] * mask[:, :, None]
            dots.masked_fill_(~mask, mask_value)
            del mask

        # softmax normalization -> attention matrix
        attn = dots.softmax(dim=-1)
        # value * attention matrix -> output
        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        # cat all output -> [b, n, head_num*head_dim]
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_heads, mlp_dim, dropout):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(LayerNormalize(dim, Attention(dim, heads=heads, dim_heads=dim_heads, dropout=dropout))),
                Residual(LayerNormalize(dim, MLP_Block(dim, mlp_dim, dropout=dropout)))
            ]))

    def forward(self, x, mask=None):
        print('i am here Transformer')
        for attention, mlp in self.layers:
            x = attention(x, mask=mask)  # go to attention
            x = mlp(x)  # go to MLP_Block
        return x



class SE(nn.Module):

    def __init__(self, in_chnls, ratio):
        super(SE, self).__init__()
        self.squeeze = nn.AdaptiveAvgPool2d((1, 1))
        self.compress = nn.Conv2d(in_chnls, in_chnls//ratio, 1, 1, 0)
        self.excitation = nn.Conv2d(in_chnls//ratio, in_chnls, 1, 1, 0)

    def forward(self, x):
        out = self.squeeze(x)
        out = self.compress(out)
        out = F.relu(out)
        out = self.excitation(out)
        return torch.sigmoid(out) # 2023.09.21 运行这里报错了 原来是 F.sigmoid(out)



class HSINet(nn.Module):
    def __init__(self, params):
        super(HSINet, self).__init__()
        self.params = params
        net_params = params['net']
        data_params = params['data']

        num_classes = data_params.get("num_classes", 16)
        patch_size = data_params.get("patch_size", 13)
        self.spectral_size = data_params.get("spectral_size", 200)

        depth = net_params.get("depth", 1)
        heads = net_params.get("heads", 8)
        mlp_dim = net_params.get("mlp_dim", 8)
        dropout = net_params.get("dropout", 0)
        conv2d_out = 64
        dim = net_params.get("dim", 64)
        dim_heads = dim
        mlp_head_dim = dim
        
        image_size = patch_size * patch_size

        self.pixel_patch_embedding = nn.Linear(conv2d_out, dim)

        self.local_trans_pixel = Transformer(dim=dim, depth=depth, heads=heads, dim_heads=dim_heads, mlp_dim=mlp_dim, dropout=dropout)
        self.new_image_size = image_size
        self.pixel_pos_embedding = nn.Parameter(torch.randn(1, self.new_image_size+1, dim))
        self.pixel_pos_scale = nn.Parameter(torch.ones(1) * 0.01)

        self.conv2d_features = nn.Sequential(
            nn.Conv2d(in_channels=self.spectral_size, out_channels=conv2d_out, kernel_size=(3, 3), padding=(1,1)),
            nn.BatchNorm2d(conv2d_out),
            nn.ReLU(),
            # featuremap 是在这之后加一层channel上的压缩
            # nn.Conv2d(in_channels=conv2d_out,out_channels=dim,kernel_size=3,padding=1),
            # nn.BatchNorm2d(dim),
            # nn.ReLU()
        )

#         self.senet = SE(conv2d_out, 5)

        self.cls_token_pixel = nn.Parameter(torch.randn(1, 1, dim))
        self.to_latent_pixel = nn.Identity()

        self.mlp_head =nn.Linear(dim, num_classes)
        torch.nn.init.xavier_uniform_(self.mlp_head.weight)
        torch.nn.init.normal_(self.mlp_head.bias, std=1e-6)
        self.dropout = nn.Dropout(0.1)

        linear_dim = dim * 2
        self.classifier_mlp = nn.Sequential(
            nn.Linear(dim, linear_dim),
            nn.BatchNorm1d(linear_dim),
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Linear(linear_dim, num_classes),
        )

    def encoder_block(self, x):
        '''
        x: (batch, s, w, h), s=spectral, w=weigth, h=height
        '''
        print(f'x.shape1 is {x.shape}')
        x_pixel = x 

        b, s, w, h = x_pixel.shape
        img = w * h
        x_pixel = self.conv2d_features(x_pixel)
        print(f'x.shape2 is {x_pixel.shape}')


#         scale = self.senet(x_pixel)
#         print(f'x.shape3 is {scale.shape}')
        # print('scale shape is ', scale.shape)
        # print('pixel shape is ', x_pixel.shape)
        # x_pixel = x_pixel * scale#(batch, image_size, dim)

        #1. reshape
        x_pixel = rearrange(x_pixel, 'b s w h-> b (w h) s') # (batch, w*h, s)
        print(f'x.shape4 is {x_pixel.shape}')

        #2. patch_embedding
        # x_pixel = self.pixel_patch_embedding(x_pixel)

        #3. local transformer
        cls_tokens_pixel = self.cls_token_pixel.expand(x_pixel.shape[0], -1, -1)
        print(f'x.shape5 is {cls_tokens_pixel.shape}')
        x_pixel = torch.cat((cls_tokens_pixel, x_pixel), dim = 1) #[b,image+1,dim]
        print(f'x.shape6 is {x_pixel.shape}')
        x_pixel = x_pixel + self.pixel_pos_embedding[:,:] * self.pixel_pos_scale
        print(f'x.shape7 is {x_pixel.shape}')
        # x_pixel = x_pixel + self.pixel_pos_embedding[:,:] 
        # x_pixel = self.dropout(x_pixel)

        x_pixel = self.local_trans_pixel(x_pixel) #(batch, image_size+1, dim)
        print(f'x.shape8 is {x_pixel.shape}')

#         out = (torch.bmm(out, torch.transpose(out, 1, 2)) / feature_size).view(batch_size, -1) #b, 4096
        
        logit_pixel = self.to_latent_pixel(x_pixel[:,0])
        print(f'x.shape9 is {logit_pixel.shape}')
        
    
        
        logit_x = logit_pixel 
        reduce_x = torch.mean(x_pixel, dim=1)
        
        return logit_x, reduce_x

    def forward(self, x,left=None,right=None):
        '''
        x: (batch, s, w, h), s=spectral, w=weigth, h=height

        '''
        logit_x, _ = self.encoder_block(x)
#         print(f'logit_x shape is {logit_x.shape}')
        mean_left, mean_right = None, None
        if left is not None and right is not None:
            _, mean_left = self.encoder_block(left)
            _, mean_right = self.encoder_block(right)

        # return  self.mlp_head(logit_x), mean_left, mean_right 
        return  self.classifier_mlp(logit_x), mean_left, mean_right 

In [89]:
path_param = './params/indian_diffusion.json'
with open(path_param, 'r') as fin:
    param = json.loads(fin.read())
net = HSINet(param)
# x = torch.randn(4, 1,1200,13,13)
# y = net(x) 
# print(y.shape)

# from torchsummary import summary
# summary(net.cuda(), (1200, 13, 13), batch_size=128) # 1：batch_size 3:图片的通道数 224: 图片的高宽

import torchinfo
torchinfo.summary(net.cuda(), (1200, 13, 13), batch_dim = 0, col_names = ("input_size", "output_size",  "kernel_size"), verbose = 0)

x.shape1 is torch.Size([1, 1200, 13, 13])
x.shape2 is torch.Size([1, 64, 13, 13])
x.shape4 is torch.Size([1, 169, 64])
x.shape5 is torch.Size([1, 1, 64])
x.shape6 is torch.Size([1, 170, 64])
x.shape7 is torch.Size([1, 170, 64])
i am here Transformer
x.shape8 is torch.Size([1, 170, 64])
x.shape9 is torch.Size([1, 64])


Layer (type:depth-idx)                                  Input Shape               Output Shape              Kernel Shape
HSINet                                                  [1, 1200, 13, 13]         [1, 16]                   --
├─Sequential: 1-1                                       [1, 1200, 13, 13]         [1, 64, 13, 13]           --
│    └─Conv2d: 2-1                                      [1, 1200, 13, 13]         [1, 64, 13, 13]           [3, 3]
│    └─BatchNorm2d: 2-2                                 [1, 64, 13, 13]           [1, 64, 13, 13]           --
│    └─ReLU: 2-3                                        [1, 64, 13, 13]           [1, 64, 13, 13]           --
├─Transformer: 1-2                                      [1, 170, 64]              [1, 170, 64]              --
│    └─ModuleList: 2-4                                  --                        --                        --
│    │    └─ModuleList: 3-1                             --                        --              

In [17]:
print(net)

HSINet(
  (pixel_patch_embedding): Linear(in_features=64, out_features=64, bias=True)
  (local_trans_pixel): Transformer(
    (layers): ModuleList(
      (0): ModuleList(
        (0): Residual(
          (fn): LayerNormalize(
            (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
            (fn): Attention(
              (to_qkv): Linear(in_features=64, out_features=3840, bias=False)
              (to_out): Sequential(
                (0): Linear(in_features=1280, out_features=64, bias=True)
                (1): Dropout(p=0, inplace=False)
              )
            )
          )
        )
        (1): Residual(
          (fn): LayerNormalize(
            (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
            (fn): MLP_Block(
              (net): Sequential(
                (0): Linear(in_features=64, out_features=8, bias=True)
                (1): GELU(approximate=none)
                (2): Dropout(p=0, inplace=False)
                (3): Linear(i

In [18]:
from torchsummary import summary
summary(net.cuda(), (1200, 13, 13), batch_size=128) # 1：batch_size 3:图片的通道数 224: 图片的高宽

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [128, 64, 13, 13]         691,264
       BatchNorm2d-2          [128, 64, 13, 13]             128
              ReLU-3          [128, 64, 13, 13]               0
 AdaptiveAvgPool2d-4            [128, 64, 1, 1]               0
            Conv2d-5            [128, 12, 1, 1]             780
            Conv2d-6            [128, 64, 1, 1]             832
                SE-7            [128, 64, 1, 1]               0
         LayerNorm-8             [128, 170, 64]             128
            Linear-9           [128, 170, 3840]         245,760
           Linear-10             [128, 170, 64]          81,984
          Dropout-11             [128, 170, 64]               0
        Attention-12             [128, 170, 64]               0
   LayerNormalize-13             [128, 170, 64]               0
         Residual-14             [128, 

In [76]:
path_param = './params/indian_diffusion.json'
with open(path_param, 'r') as fin:
    param = json.loads(fin.read())
net2 = HSINet(param).cuda()
x = torch.randn(128,1200,13,13).cuda()
y = net2(x)
print(y[0].shape)

x.shape1 is torch.Size([128, 1200, 13, 13])
x.shape2 is torch.Size([128, 64, 13, 13])
x.shape3 is torch.Size([128, 64, 1, 1])
x.shape4 is torch.Size([128, 169, 64])
x.shape5 is torch.Size([128, 1, 64])
x.shape6 is torch.Size([128, 170, 64])
x.shape7 is torch.Size([128, 170, 64])
i am here Transformer
x.shape8 is torch.Size([128, 170, 64])
type is torch.Size([128, 64])
x.shape9 is torch.Size([128, 64])
torch.Size([128, 16])
