In [3]:
# Load packages and classes
import os

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tiffslide
import seaborn as sns
import gget
import tifffile
import zarr

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# MosaicDataset and BruceDataset classes allow loading and visualisation of the different data sources
#from gbmhackathon import MosaicDataset, BruceDataset

In [16]:
# Show the working directory: the embedding paths loaded below are relative to it.
print(os.getcwd())

/home/ec2-user/SageMaker/gbm_hackathon/notebooks


In [18]:
# Load the precomputed embeddings: single-cell first, then bulk.
scell, bulk = (
    np.load(npy_path)
    for npy_path in (
        './scFoundation_sc_embedding/embedding.npy',
        './Foundation_bulk/embedding_bulk.npy',
    )
)

In [20]:
# Sanity-check the loaded arrays: one shape per line (single-cell, then bulk).
print(scell.shape, bulk.shape, sep="\n")

(18614, 3072)
(104, 3072)


In [None]:
# 🔹 Projection of the embeddings into a shared space
class ProjectionHead(nn.Module):
    """Two-layer MLP that maps embeddings into a shared, L2-normalized space.

    Args:
        input_dim: Size of the last dimension of the input embeddings.
        output_dim: Size of the projected embeddings.
        hidden_dim: Width of the hidden layer. Defaults to 128, the value
            that was previously hard-coded.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=128):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        """Map ``x`` of shape (..., input_dim) to (..., output_dim) with unit L2 norm on the last dim."""
        # Unit-norm outputs mean a dot product between two projections is their cosine similarity.
        return F.normalize(self.fc(x), dim=-1)

# 🔹 Model definition
class ContrastiveModel(nn.Module):
    """Project image, bulk and mutation embeddings into one shared space.

    Each modality gets its own ``ProjectionHead``. All heads use the same
    ``output_dim``, so their outputs can be compared directly.
    """

    def __init__(self, dim_img, dim_bulk, dim_mut, output_dim):
        super().__init__()
        # Registration order (img, bulk, mut) fixes the parameter order seen by the optimizer.
        self.proj_img = ProjectionHead(dim_img, output_dim)
        self.proj_bulk = ProjectionHead(dim_bulk, output_dim)
        self.proj_mut = ProjectionHead(dim_mut, output_dim)

    def forward(self, img, bulk, mut):
        """Return the projected ``(z_img, z_bulk, z_mut)`` tuple, one entry per input."""
        heads = (self.proj_img, self.proj_bulk, self.proj_mut)
        inputs = (img, bulk, mut)
        return tuple(head(x) for head, x in zip(heads, inputs))

# 🔹 Contrastive loss (InfoNCE)
def contrastive_loss(z1, z2, temperature=0.1, symmetric=False):
    """InfoNCE loss between two row-paired batches of embeddings.

    Row ``i`` of ``z1`` is the positive match for row ``i`` of ``z2``. Every
    other row of ``z2`` acts as a negative.

    Args:
        z1: 2-D tensor (n1, dim).
        z2: 2-D tensor (n2, dim) with n2 >= n1, paired row-wise with ``z1``.
        temperature: Positive divisor applied to the similarity logits.
            Smaller values sharpen the softmax.
        symmetric: If True, average the z1->z2 and z2->z1 losses (CLIP-style).
            This requires n1 == n2. The default (False) is the one-directional loss.

    Returns:
        Scalar tensor: the mean cross-entropy over the rows of ``z1``.

    Raises:
        ValueError: if ``z1`` has more rows than ``z2`` (some positives would
            not exist), if ``symmetric`` is set with unequal row counts, or if
            ``temperature`` is not positive.
    """
    n1, n2 = z1.shape[0], z2.shape[0]
    if n1 > n2:
        # Without this check, cross_entropy gets out-of-range labels
        # (a cryptic device-side assert on CUDA).
        raise ValueError(
            f"z1 has {n1} rows but z2 only {n2}; row i of z1 needs a positive at row i of z2"
        )
    if symmetric and n1 != n2:
        raise ValueError(f"symmetric loss needs equal batch sizes, got {n1} and {n2}")
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    # Pairwise dot products (cosine similarities when inputs are unit-norm).
    # torch.mm, rather than @, keeps rejecting non-2-D inputs.
    logits = torch.mm(z1, z2.T) / temperature
    labels = torch.arange(n1, device=z1.device)  # row i's positive is column i
    loss = F.cross_entropy(logits, labels)
    if symmetric:
        loss = 0.5 * (loss + F.cross_entropy(logits.T, labels))
    return loss

# 🔹 Dummy data (replace with your real embeddings)
# NOTE(review): the real arrays loaded above (`scell`, `bulk`) are not wired in yet;
# each modality must be row-aligned per patient before it can replace these tensors.
SEED = 0
torch.manual_seed(SEED)  # makes the dummy data and the weight init reproducible

N, d1, d2, d3 = 100, 512, 256, 128  # 100 patients, a different embedding size per modality
emb_img = torch.randn(N, d1)
emb_bulk = torch.randn(N, d2)
emb_mut = torch.randn(N, d3)

# 🔹 Training setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ContrastiveModel(d1, d2, d3, output_dim=64).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# The inputs never change between epochs, so move them to the device once, outside the loop.
x_img, x_bulk, x_mut = (t.to(device) for t in (emb_img, emb_bulk, emb_mut))

# 🔹 Training loop: full batch; the loss sums InfoNCE over all three modality pairs.
epochs = 100
log_every = 10
for epoch in range(epochs):
    optimizer.zero_grad()
    z_img, z_bulk, z_mut = model(x_img, x_bulk, x_mut)

    loss = contrastive_loss(z_img, z_bulk) + contrastive_loss(z_img, z_mut) + contrastive_loss(z_bulk, z_mut)
    loss.backward()
    optimizer.step()

    # Log periodically, and always report the final epoch.
    if epoch % log_every == 0 or epoch == epochs - 1:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
