# In [35]:
# %%
import os
import gc
import torch
import random
import pytorch_lightning as pl
import pandas as pd
import numpy as np
import cv2
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch_geometric.nn import GCNConv, BatchNorm
from torch_geometric.nn.models import GAE
from torch.utils.data import DataLoader
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks.early_stopping import EarlyStopping


class FeatureExtractor(nn.Module):
    """Pretrained torchvision CNN with its classification head removed.

    The backbone's final ``fc`` layer is replaced by ``nn.Identity`` so the
    forward pass returns the feature vector that would have been fed to the
    classifier instead of class logits.

    Args:
        backbone: Name of a torchvision model constructor whose model exposes
            an ``nn.Linear`` ``fc`` head (e.g. ``'resnet18'``, ``'resnet50'``).
            Defaults to ``'resnet50'``.

    Attributes:
        out_features: Width of the returned feature vector (the ``in_features``
            of the removed ``fc`` layer).

    Raises:
        ValueError: If ``backbone`` is not a torchvision model constructor or
            the built model has no ``nn.Linear`` ``fc`` head to strip.
    """

    def __init__(self, backbone='resnet50'):
        super(FeatureExtractor, self).__init__()
        # Resolve the constructor by name so the argument is actually honoured
        # (it used to be ignored and resnet50 was always built). Submodules
        # such as `torchvision.models.resnet` are not callable and get rejected.
        model_fn = getattr(torchvision.models, backbone, None)
        if model_fn is None or not callable(model_fn):
            raise ValueError(f"Unknown torchvision backbone: {backbone!r}")
        self.backbone = model_fn(pretrained=True)
        head = getattr(self.backbone, 'fc', None)
        if not isinstance(head, nn.Linear):
            raise ValueError(
                f"Backbone {backbone!r} has no nn.Linear 'fc' head to remove"
            )
        # Record the feature width before the head is discarded.
        self.out_features = head.in_features
        self.backbone.fc = nn.Identity()

    def forward(self, x):
        """Return backbone features for the image batch ``x``.

        Assumes ``x`` is laid out as torchvision classification models expect,
        i.e. ``(B, 3, H, W)``. Output is ``(B, out_features)``.
        """
        return self.backbone(x)

class GNN(nn.Module):
    """Three parallel GCN branches whose outputs are fused by an LSTM.

    Every ``GCNConv`` branch reads the same node features ``x``. The three
    branch outputs are stacked along a new leading axis of length 3 and
    treated as a length-3 sequence by a 2-layer LSTM (``batch_first=False``,
    so nodes play the role of the batch). The LSTM outputs are then averaged
    over that sequence axis.

    NOTE(review): the branches are applied in parallel, not stacked; confirm
    this is intended if a deeper (layer-on-layer) aggregation was meant.

    Args:
        in_channels: Feature width of the input node features.
        out_channels: Feature width produced by each GCN branch and by the LSTM.
    """

    def __init__(self, in_channels=512, out_channels=512):
        super().__init__()
        self.conv = nn.ModuleList(
            [GCNConv(in_channels, out_channels) for _ in range(3)]
        )
        # The LSTM consumes the GCN outputs, which have `out_channels`
        # features. Sizing it with `in_channels` only worked when the two
        # were equal (for that default case the weight shapes are unchanged).
        self.lstm = nn.LSTM(out_channels, out_channels, 2)

    def forward(self, x, edge_index, edge_weights):
        """Return fused node embeddings of shape ``(num_nodes, out_channels)``.

        Args:
            x: Node features, ``(num_nodes, in_channels)``.
            edge_index: Graph connectivity passed straight to ``GCNConv``.
            edge_weights: Per-edge weights passed straight to ``GCNConv``.
        """
        # (3, num_nodes, out_channels): one slice per GCN branch.
        branches = torch.stack(
            [gcn(x, edge_index, edge_weights) for gcn in self.conv], 0
        )
        seq_out, _ = self.lstm(branches)
        return seq_out.mean(0)

class Autoencoder(nn.Module):
    """Single-layer linear autoencoder used to compress feature vectors.

    Args:
        hidden_dim: Size of the bottleneck code ``h``.
        input_dim: Size of the input feature vector.
        dropout: Dropout probability applied to the code ``h``. It is active
            only in training mode.

    ``forward(x)`` returns ``(reconstruction, code, reconstruction_mse)``,
    where the code is the (dropped-out) encoder output and the MSE is the
    mean squared error between reconstruction and ``x``.
    """

    def __init__(self, hidden_dim=512, input_dim=2048, dropout=0.2):
        super(Autoencoder, self).__init__()
        self.encoder = nn.Linear(input_dim, hidden_dim)
        self.decoder = nn.Linear(hidden_dim, input_dim)
        # Registered as a submodule so that .train()/.eval() toggle it. A
        # Dropout constructed inline inside forward() is always in training
        # mode, so it would keep dropping units during evaluation/inference.
        # Dropout has no parameters, so the state_dict keys are unaffected.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        h = self.dropout(self.encoder(x))
        out = self.decoder(h)
        return out, h, F.mse_loss(out, x)
    
class CNN_AE_GNN(pl.LightningModule):
    """CNN -> autoencoder -> GNN regressor, wrapped as a LightningModule.

    Pipeline (see ``forward``): image patches go through ``FeatureExtractor``
    (2048-d features), are compressed by ``Autoencoder`` to ``hidden_dim``,
    are refined by ``GNN`` over the supplied graph, and are mapped by a linear
    head (after ReLU) to ``n_genes`` outputs per node.

    Training/eval loss = MSE(prediction, target) + autoencoder reconstruction MSE.

    Expected batch layout (from the unpacking in the step methods):
    ``batch[0]`` patches, ``batch[2]`` targets, ``batch[-2]`` edge_index,
    ``batch[-1]`` edge_weights; every other element is ignored. Patches carry
    a leading dimension of size 1 that is squeezed away. This presumably means
    a DataLoader with batch_size=1 (one graph per batch) — TODO confirm.

    Args:
        n_genes: Number of outputs predicted per node.
        hidden_dim: Autoencoder code size and GNN feature width.
        learning_rate: Adam learning rate.
    """

    def __init__(self, n_genes=1630, hidden_dim=512, learning_rate=1e-4,):
        super().__init__()
        self.save_hyperparameters()
        self.feature_extractor = FeatureExtractor()
        # Fine-tune the whole CNN. Parameters are trainable by default; this
        # loop just makes that intent explicit.
        for param in self.feature_extractor.parameters():
            param.requires_grad = True
        # The GNN consumes the AE code, so the AE must use the same hidden_dim
        # (it was hard-coded to 512, which broke any non-default hidden_dim).
        self.AE = Autoencoder(hidden_dim=hidden_dim, input_dim=2048)
        self.GNN = GNN(in_channels=hidden_dim, out_channels=hidden_dim)
        self.pred_head = nn.Linear(hidden_dim, n_genes)
        self.learning_rate = learning_rate
        self.n_genes = n_genes

    def forward(self, patch, edge_index, edge_weights):
        """Return ``(prediction, node_embedding, reconstruction_loss)``."""
        features = self.feature_extractor(patch)
        _, code, recon_loss = self.AE(features)
        h = self.GNN(code, edge_index, edge_weights)
        pred = self.pred_head(F.relu(h))
        return pred, h, recon_loss

    def _shared_step(self, batch, stage):
        """Compute, log as ``f'{stage}_loss'``, and return the combined loss."""
        patch, _, exp, *_, edge_index, edge_weights = batch
        pred, _, recon_loss = self(patch.squeeze(0), edge_index, edge_weights)
        # NOTE(review): `exp` is not squeezed like `patch`; if it keeps a
        # leading size-1 dim, mse_loss broadcasts (same value, with a warning).
        loss = F.mse_loss(pred, exp) + recon_loss
        self.log(f'{stage}_loss', loss)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, 'test')

    def predict_step(self, batch, batch_idx):
        """Return ``(prediction, target)`` as NumPy arrays with dim 0 squeezed."""
        patch, _, exp, *_, edge_index, edge_weights = batch
        pred, *_ = self(patch.squeeze(0), edge_index, edge_weights)
        # detach() lets this also work outside Lightning's no-grad predict
        # loop, since .numpy() refuses tensors that require grad.
        pred = pred.squeeze(0).detach().cpu().numpy()
        exp = exp.squeeze(0).detach().cpu().numpy()
        return pred, exp

    def configure_optimizers(self):
        # learning_rate is also recorded in self.hparams via save_hyperparameters().
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer
