In [111]:
import torch
import pandas as pd
import torch.nn as nn
import pytorch_lightning as pl
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


class DINO_4k_feature_dataset(Dataset):
    def __init__(self, target):
        features = pd.read_csv('/home/rens/repos/premium_pathology/hipt_feature_extractor/data/features.csv').set_index('Unnamed: 0')
        labels = pd.read_csv('/home/rens/repos/premium_pathology/hipt_feature_extractor/data/labels.csv').set_index('Unnamed: 0')
        labels.index = [Path(ix).stem for ix in labels.index]

        self.target = target

        self.data = features.join(labels[target], on='slide')
        self.data[target] = self.data[target].astype(int)

        self.slides = self.data.slide.unique()
        
        self.transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.float())
        ])

    def __len__(self):
        return len(self.slides)
    
    def __getitem__(self, ix):
        slide = self.slides[ix]

        slide_data = self.data[self.data.slide == slide]

        features = slide_data[[f'feature{f}' for f in range(192)]].to_numpy()
        x = self.transforms(features)

        y = slide_data[self.target][0]

        return x, y
        
class GatedAttentionLayer(nn.Module):
    def __init__(self, input_dim, output_dim, dropout=0):
        super().__init__()
        
        self.attention_a = [
            nn.Linear(input_dim,output_dim),
            nn.Tanh()
        ]
        self.attention_b = [
            nn.Linear(input_dim,output_dim),
            nn.Sigmoid()
        ]

        self.attention_c = [
            nn.Linear(output_dim, 1),
            nn.Softmax(dim=1)
        ]

        if dropout:
            self.attention_a.append(nn.Dropout(dropout))
            self.attention_b.append(nn.Dropout(dropout))

        self.attention_a = nn.Sequential(*self.attention_a)
        self.attention_b = nn.Sequential(*self.attention_b)
        self.attention_c = nn.Sequential(*self.attention_c)

    def forward(self, x):
        a = self.attention_a(x)
        b = self.attention_b(x)
        A = a.mul(b)
        out = self.attention_c(A)

        return out
    
class AttentionModel(pl.LightningModule):
    def __init__(self, dropout=0):
        super().__init__()

        self.attention_layer = GatedAttentionLayer(192, 128, dropout)
        self.classifier = nn.Sequential(*[
            nn.Linear(192, 1),
            nn.Sigmoid()
        ])

    def forward(self, x):
        A = self.attention_layer(x).transpose(1,2)
        M = torch.matmul(A, x)
        out = self.classifier(M)

        return out


In [113]:
ds = DINO_4k_feature_dataset('primary')
x, y = ds[0]

model = AttentionModel()

model(x)

tensor([[[0.5584]]], grad_fn=<SigmoidBackward0>)

In [114]:
y

0

In [109]:
torch.tensor(0)

tensor(0)

In [89]:
dl = DataLoader(ds, 1, shuffle=False)

for x in dl:
    out = model(x)
    break


RuntimeError: expected scalar type Double but found Float

In [81]:
x = x.float()
A = attention(x).transpose(1,2)

M = torch.matmul(A,x)

classifier = 

out = classifier(M)


In [82]:
out

tensor([[[0.5910]]], grad_fn=<SigmoidBackward0>)

In [62]:
x.shape

torch.Size([1, 81, 192])