In [1]:
# Pin the notebook to physical GPU 1. A `!export` runs in a throw-away subshell
# and never reaches this kernel, so set the variable in-process instead —
# before torch is imported and CUDA gets initialised.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.models import convnext_large, ConvNeXt_Large_Weights
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
import timm
from PIL import Image
import glob
import csv
import random
import numpy as np
import os
import pandas as pd


In [2]:
def set_seed(seed):
    """Seed every RNG the notebook imports so runs are repeatable.

    Seeds Python's ``random`` module (imported at the top but previously left
    unseeded), NumPy's legacy global RNG, and torch's CPU and all-GPU generators.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


set_seed(2022)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [3]:
# Pretrained ConvNeXt-Large backbone with torchvision's DEFAULT weights.
# The heads below consume pre-extracted ConvNext.features maps (1536 x 7 x 7);
# this object is not referenced by the other cells visible here.
ConvNext = convnext_large(
    weights=ConvNeXt_Large_Weights.DEFAULT,
)

In [4]:
class LKA(nn.Module):
    """Large-kernel attention: gate the input with a decomposed large-kernel conv.

    A 5x5 depthwise conv, a 7x7 depthwise conv with dilation 3, and a 1x1
    pointwise conv produce an attention map shaped like the input. That map
    multiplies the input element-wise. Paddings preserve the spatial size
    (5x5 with pad 2; 7x7 with dilation 3 spans 19 pixels, pad 9).

    Args:
        dim: number of input (and output) channels.
    """

    def __init__(self, dim):
        super().__init__()
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        self.conv_spatial = nn.Conv2d(dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)
        self.conv1 = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        """Return ``x * attn(x)``; the output has the same shape as ``x``."""
        # No clone needed: the convs allocate new tensors and never modify `x`.
        attn = self.conv0(x)
        attn = self.conv_spatial(attn)
        attn = self.conv1(attn)
        return x * attn


class Attention(nn.Module):
    """Visual attention block: 1x1 proj -> GELU -> LKA gating -> 1x1 proj, plus residual.

    Every stage preserves shape, so the output matches the (N, d_model, H, W) input.

    Args:
        d_model: channel count of the feature map.
    """

    def __init__(self, d_model):
        super().__init__()

        self.proj_1 = nn.Conv2d(d_model, d_model, 1)
        self.activation = nn.GELU()
        self.spatial_gating_unit = LKA(d_model)
        self.proj_2 = nn.Conv2d(d_model, d_model, 1)

    def forward(self, x):
        # Residual branch. `x` is only rebound below (never modified in place),
        # so holding a reference is enough; no clone is required.
        shortcut = x
        x = self.proj_1(x)
        x = self.activation(x)
        x = self.spatial_gating_unit(x)
        x = self.proj_2(x)
        return x + shortcut

In [5]:
class ModelConfig:
    """Hyper-parameter container for the MIL heads in this notebook.

    Class attributes hold the defaults. Every keyword passed to the constructor
    is stored on the instance, either overriding a default or adding a new field.

    Attributes:
        classes: number of instance-classifier outputs.
        att_dim: channel count fed to the visual ``Attention`` block.
        in_size: feature length per instance.
        out_size: channel count of the bag-level Conv1d head.
        dropout: dropout probability in the key/query/value projections.
    """

    classes = 3
    att_dim = 1536
    in_size = 1536
    out_size = 3
    dropout = 0.5

    def __init__(self, **kwargs):
        vars(self).update(kwargs)

In [6]:
# class InstanceClassifier(nn.Module):
#     def __init__(self,config):
#         super(InstanceClassifier, self).__init__()
#         # self.FE = FeatureExtractor
#         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
#         self.classifier = nn.Linear(config.in_size, config.classes)
#     def forward(self,features):
#         # features = self.FE(x)

#         h = self.avgpool(features)
       
#         feats = h.view(h.size(0), -1)
    
#         C = self.classifier(feats)

#         return feats, C

In [7]:
class InstanceClassifier(nn.Module):
    """Per-instance linear classifier used by the dual-stream MIL heads.

    Applies ``nn.Linear(config.in_size, config.classes)`` to the last dimension
    of ``features``. Returns the features unchanged together with the scores.

    Args:
        config: object exposing ``in_size`` and ``classes``.
    """

    def __init__(self, config):
        super().__init__()
        self.classifier = nn.Linear(config.in_size, config.classes)

    def forward(self, features):
        """Return ``(features, scores)``; ``scores`` has last dimension ``classes``."""
        scores = self.classifier(features)
        return features, scores

In [7]:
# Smoke test: the Attention block should return a tensor shaped like its input.
vis_att = Attention(1536).to(device)
vis_att(torch.randn(1, 1536, 7, 7).to(device)).shape

torch.Size([1, 1536, 7, 7])

In [7]:
class AttDual(nn.Module):
    """Dual-stream (DSMIL-style) MIL head over already-flattened instance features.

    Stream 1: an ``InstanceClassifier`` scores every instance.
    Stream 2: for each class, the top-scoring ("critical") instance provides a
    query. Every instance is attended against it, producing one bag embedding
    per class. A Conv1d head then maps those embeddings to bag scores.

    Args:
        config: exposes ``in_size`` (feature length), ``classes``, ``out_size``
            (must equal ``classes``, since the head's input channels are the
            per-class embeddings) and ``dropout``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.i_classifier = InstanceClassifier(config)

        # in_size (1536) is the per-instance feature length.
        self.key = nn.Sequential(
            nn.Linear(config.in_size, config.in_size),
            nn.Dropout(config.dropout),
            nn.LayerNorm(config.in_size),
            nn.GELU(),
        )

        self.query = nn.Sequential(
            nn.Linear(config.in_size, config.in_size),
            nn.Dropout(config.dropout),
            nn.LayerNorm(config.in_size),
            nn.GELU(),
        )

        self.value = nn.Sequential(
            nn.Linear(config.in_size, config.in_size),
            nn.Dropout(config.dropout),
            nn.LayerNorm(config.in_size),
            nn.GELU(),
        )

        # A kernel spanning the whole embedding collapses (1, C, in_size) to (1, C, 1).
        self.head = nn.Conv1d(config.out_size, config.out_size, kernel_size=config.in_size)

        self.apply(self.init_weights)

    def init_weights(self, m):
        """Kaiming-normal weights and zero biases for Linear/Conv; identity LayerNorm."""
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight)
            nn.init.constant_(m.bias, 0)

        elif isinstance(m, (nn.Conv2d, nn.Conv1d)):
            nn.init.kaiming_normal_(m.weight)
            nn.init.constant_(m.bias, 0)

        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    def forward(self, features):
        """Score one bag of N instances.

        Args:
            features: (N, in_size) tensor, one row per instance. ``torch.mm``
                below requires 2-D inputs.

        Returns:
            C: (1, out_size) bag-level scores.
            A: (N, classes) attention weights, softmax-normalised over instances.
            B: (1, classes, in_size) per-class bag embeddings.
        """
        # InstanceClassifier passes the features through unchanged and adds
        # per-instance class scores c of shape (N, classes).
        features, c = self.i_classifier(features)

        K = self.key(features)  # (N, in_size)
        V = self.value(K)       # (N, in_size); V and Q are both derived from K
        Q = self.query(K)       # (N, in_size)

        # Row 0 of a descending sort along the instance axis is, per class,
        # the index of the highest-scoring ("critical") instance.
        _, m_indices = torch.sort(c, 0, descending=True)               # (N, classes)
        m_feats = torch.index_select(K, dim=0, index=m_indices[0, :])  # (classes, in_size)
        q_max = self.query(m_feats)                                    # (classes, in_size)

        # Scaled dot-product of every instance query with each critical query.
        A = torch.mm(Q, q_max.transpose(0, 1))                         # (N, classes)
        # Scale by a plain Python float so the op follows A's device instead of
        # depending on the notebook-global `device` (which later cells rebind).
        A = F.softmax(A / (Q.shape[1] ** 0.5), dim=0)                  # normalise over instances
        B = torch.mm(A.transpose(0, 1), V)                             # (classes, in_size)

        B = B.view(1, B.shape[0], B.shape[1])                          # (1, classes, in_size)
        C = self.head(B)                                               # (1, out_size, 1)
        C = C.view(1, -1)                                              # (1, out_size)
        return C, A, B


In [None]:
# `config` is only created in a later cell, so reading it here raises NameError
# on Restart & Run All. The ModelConfig class default (1536) is the same value.
vis_att = Attention(ModelConfig.att_dim)

In [None]:
class VisAttDual(nn.Module):
    """Dual-stream MIL head with a visual (LKA) attention stage in front.

    Each instance arrives as a flattened backbone feature map. The map is
    reshaped to (N, att_dim, side, side), refined by ``Attention``, and
    average-pooled to an (N, att_dim) vector. It is then aggregated like
    ``AttDual``: the instance classifier picks one critical instance per class,
    every instance is attended against it, and a Conv1d head maps the
    per-class bag embeddings to bag scores.

    Args:
        config: exposes ``att_dim`` (feature-map channels), ``in_size`` (must
            equal ``att_dim``, since the pooled maps feed the Linear layers),
            ``classes``, ``out_size`` (must equal ``classes``) and ``dropout``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.vis_att = Attention(config.att_dim)  # visual attention over att_dim (1536) channels
        # Parameter-free; collapses each (att_dim, side, side) map to att_dim values.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.i_classifier = InstanceClassifier(config)

        # in_size (1536) is the per-instance feature length after pooling.
        self.key = nn.Sequential(
            nn.Linear(config.in_size, config.in_size),
            nn.Dropout(config.dropout),
            nn.LayerNorm(config.in_size),
            nn.GELU(),
        )

        self.query = nn.Sequential(
            nn.Linear(config.in_size, config.in_size),
            nn.Dropout(config.dropout),
            nn.LayerNorm(config.in_size),
            nn.GELU(),
        )

        self.value = nn.Sequential(
            nn.Linear(config.in_size, config.in_size),
            nn.Dropout(config.dropout),
            nn.LayerNorm(config.in_size),
            nn.GELU(),
        )

        # A kernel spanning the whole embedding collapses (1, C, in_size) to (1, C, 1).
        self.head = nn.Conv1d(config.out_size, config.out_size, kernel_size=config.in_size)

        self.apply(self.init_weights)

    def init_weights(self, m):
        """Kaiming-normal weights and zero biases for Linear/Conv; identity LayerNorm."""
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight)
            nn.init.constant_(m.bias, 0)

        elif isinstance(m, (nn.Conv2d, nn.Conv1d)):
            nn.init.kaiming_normal_(m.weight)
            nn.init.constant_(m.bias, 0)

        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    def forward(self, features):
        """Score one bag of N instances.

        Args:
            features: (N, att_dim*side*side) flattened feature maps with a square
                spatial grid (side 7 for 1536*7*7 inputs), or an equally sized
                (N, att_dim, side, side) tensor.

        Returns:
            C: (1, out_size) bag-level scores.
            A: (N, classes) attention weights, softmax-normalised over instances.
            B: (1, classes, in_size) per-class bag embeddings.
        """
        # Recover the square spatial side from the per-instance element count
        # (75264 / 1536 = 49 -> 7) instead of hard-coding 1536 x 7 x 7.
        side = round((features.shape[1:].numel() / self.config.att_dim) ** 0.5)
        features = features.view(features.size(0), self.config.att_dim, side, side)
        features = self.vis_att(features)  # visual spatial attention; shape unchanged
        # Pool before the instance classifier: its nn.Linear expects (N, in_size)
        # rows. On the 4-D map it would act on the last spatial axis (size `side`)
        # and fail the shape check.
        features = self.avgpool(features).flatten(1)  # (N, att_dim)

        # InstanceClassifier passes the features through unchanged and adds
        # per-instance class scores c of shape (N, classes).
        features, c = self.i_classifier(features)

        K = self.key(features)  # (N, in_size)
        V = self.value(K)       # (N, in_size); V and Q are both derived from K
        Q = self.query(K)       # (N, in_size)

        # Row 0 of a descending sort along the instance axis is, per class,
        # the index of the highest-scoring ("critical") instance.
        _, m_indices = torch.sort(c, 0, descending=True)               # (N, classes)
        m_feats = torch.index_select(K, dim=0, index=m_indices[0, :])  # (classes, in_size)
        q_max = self.query(m_feats)                                    # (classes, in_size)

        # Scaled dot-product of every instance query with each critical query.
        A = torch.mm(Q, q_max.transpose(0, 1))                         # (N, classes)
        # Scale by a plain Python float so the op follows A's device instead of
        # depending on the notebook-global `device` (which later cells rebind).
        A = F.softmax(A / (Q.shape[1] ** 0.5), dim=0)                  # normalise over instances
        B = torch.mm(A.transpose(0, 1), V)                             # (classes, in_size)

        B = B.view(1, B.shape[0], B.shape[1])                          # (1, classes, in_size)
        C = self.head(B)                                               # (1, out_size, 1)
        C = C.view(1, -1)                                              # (1, out_size)
        return C, A, B

In [None]:
class DSMILNet(nn.Module):
    """Chain an instance classifier into a bag classifier.

    Args:
        i_classifier: callable mapping ``x`` to ``(feats, classes)``.
        b_classifier: callable taking ``(feats, classes)`` and returning
            ``(prediction_bag, A, B)``.

    NOTE(review): ``AttDual`` / ``VisAttDual`` in this notebook take a single
    ``features`` argument and run their own instance classifier. They therefore
    cannot be passed as ``b_classifier`` here without an adapter. Confirm which
    bag classifier this wrapper is meant for.
    """

    def __init__(self, i_classifier, b_classifier):
        super().__init__()
        self.i_classifier = i_classifier
        self.b_classifier = b_classifier

    def forward(self, x):
        """Return ``(instance_scores, bag_prediction, A, B)``."""
        instance_feats, instance_scores = self.i_classifier(x)
        bag_prediction, attention, bag_repr = self.b_classifier(instance_feats, instance_scores)
        return instance_scores, bag_prediction, attention, bag_repr

In [9]:
# Scratch check of row indexing: row 1 of a 4x5 arange is [5, 6, 7, 8, 9].
torch.arange(20).view(4, 5)[1]

tensor([5, 6, 7, 8, 9])

In [8]:
# Run VisAttDual on a random bag of 128 instances (flattened 1536 x 7 x 7 maps).
# Use a torch.device rather than torch.cuda.current_device(): that returns a
# bare int and raises on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
config = ModelConfig(classes=3, att_dim=1536, in_size=1536, out_size=3)
f = torch.rand(128, 1536 * 7 * 7).to(device)

m = VisAttDual(config).to(device)
c, a, b = m(f)
print(f"Bag class {c} shape: {c.shape}")
print(f"Bag attention shape: {a.shape}")
print(f"b shape: {b.shape}")

# Sanity-check the shapes the head should produce.
assert c.shape == (1, config.out_size)
assert a.shape == (f.size(0), config.classes)
assert b.shape == (1, config.classes, config.in_size)

Bag class tensor([[ 0.1915, -0.5536, -0.6501]], device='cuda:0', grad_fn=<ViewBackward0>) shape: torch.Size([1, 3])
Bag attention shape: torch.Size([128, 3])
b shape: torch.Size([1, 3, 1536])


In [16]:
class VisAtt(nn.Module):
    """Attention-MIL bag classifier with a visual (LKA) attention stage in front.

    Each tile's flattened backbone map is reshaped to (N, repr_length, side, side),
    refined by ``Attention``, and average-pooled to a repr_length vector. A gated
    MLP then scores the tiles into ``att`` attention rows, softmaxed over tiles.
    The attention-weighted tile average is classified with a sigmoid head.

    Args:
        num_classes: number of sigmoid outputs.
        repr_length: channel count of each tile's feature map (1536 for ConvNeXt-Large).
        dimension: hidden width D of the attention-scoring MLP.
        att: number of attention rows (heads).
        dropout: dropout probability inside the attention MLP.
    """

    def __init__(self, num_classes: int, repr_length = 1536, dimension = 1000, att=1, dropout=0.5):
        super().__init__()
        # The visual attention acts on the feature map's channels (repr_length).
        # Building it with `dimension` channels made its convs reject the
        # repr_length-channel input.
        self.vis_att = Attention(repr_length)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.repr_length = repr_length  # size of representation per tile
        self.D = dimension              # hidden width of the attention MLP
        self.att = att                  # number of attention rows

        self.attention = nn.Sequential(             # N * repr_length
            nn.Linear(self.repr_length, self.D),    # N * D
            nn.LayerNorm(self.D, eps=1e-06, elementwise_affine=True),
            nn.GELU(),                              # N * D
            nn.Dropout(dropout),
            nn.Linear(self.D, self.att)             # N * att
        )

        self.classifier = nn.Sequential(
            nn.Linear(self.repr_length*self.att, num_classes),
            nn.Sigmoid()
        )

        self.attention.apply(self._init_weights)
        self.classifier.apply(self._init_weights)
        self.vis_att.apply(self._init_weights)

    def _init_weights(self, m):
        """Kaiming-normal weights and zero biases for Linear/Conv2d; identity LayerNorm."""
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight)
            nn.init.constant_(m.bias, 0)

        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight)
            nn.init.constant_(m.bias, 0)

        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    def forward(self, features):
        """Classify one bag of N tiles.

        Args:
            features: (N, repr_length*side*side) flattened ConvNext.features maps
                with a square spatial grid (side 7), or an equally sized
                (N, repr_length, side, side) tensor.

        Returns:
            out_prob: (1, num_classes) sigmoid probabilities for the bag.
            A: (att, N) attention weights, softmax-normalised over the N tiles.
        """
        # Recover the square spatial side from the per-tile element count
        # (75264 / 1536 = 49 -> 7) instead of hard-coding 1536 x 7 x 7.
        side = round((features.shape[1:].numel() / self.repr_length) ** 0.5)
        features = features.view(features.size(0), self.repr_length, side, side)

        h = self.vis_att(features)     # (N, repr_length, side, side)
        h = self.avgpool(h)            # (N, repr_length, 1, 1)
        h = h.view(h.size(0), -1)      # (N, repr_length)

        A = self.attention(h)          # (N, att)
        A = torch.transpose(A, 1, 0)   # (att, N)
        A = F.softmax(A, dim=1)        # softmax over the N tiles

        M = torch.mm(A, h)             # (att, repr_length)
        # Flatten the att rows into one vector so it matches the classifier's
        # repr_length*att input for any att (a no-op reshape when att == 1).
        M = M.reshape(1, -1)           # (1, att*repr_length)
        out_prob = self.classifier(M)
        return out_prob, A



In [18]:

print(f.shape)
fv = f.view(f.size(0), 1536, 7, 7)
print(fv.shape)
# A freshly built VisAtt lives on the CPU, while `f` may be on a GPU (an earlier
# cell moves it to `device`). Put the model on the input's device.
m = VisAtt(3).to(f.device)
m(f)

torch.Size([10, 75264])
torch.Size([10, 1536, 7, 7])


(tensor([[0.4681, 0.4026, 0.5655]], grad_fn=<SigmoidBackward0>),
 tensor([[0.0187, 0.0402, 0.0429, 0.0156, 0.0741, 0.3673, 0.0907, 0.0041, 0.1246,
          0.2217]], grad_fn=<SoftmaxBackward0>))

In [49]:
import pandas as pd
# label_file_train = "/system/user/kimesweg/data/small/metadata_train_new.csv"
# label_file_test = "/system/user/kimesweg/data/small/metadata_test_new.csv"
# labels_test = pd.read_csv(label_file_test,index_col=0)
# labels_train = pd.read_csv(label_file_train,index_col=0)
# labels = pd.concat((labels_test, labels_train))
# Data_test = pd.read_csv(label_file_test)
# Data_test


Unnamed: 0.1,Unnamed: 0,1
0,133ff1855f,1
1,15b815a470,1
2,16dbf55509,1
3,16f7d9dbc2,1
4,181ec92105,1
...,...,...
127,11fffcc1f6,1
128,ICNNBCC00362,0
129,019f23d6b7,1
130,ICNNBCC00386,0
