In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets, models
import os
import numpy as np
from tqdm import tqdm


In [2]:
from scripts.SSL import MoCoLightning

Cargando Dataset
Dataset Cargado




In [3]:
resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
backbone = nn.Sequential(*list(resnet.children())[:-1])  # Quitar la capa final

encoder = MoCoLightning(
        backbone=backbone,
        lr=0.0003,          # El LR que tenías
        temperature=0.1,    # La temperatura que tenías
        queue_size=8192     # Un valor más pequeño si 65536 da OOM
    )


state_dict = torch.load("/lustre/home/atorres/MEDA_Challenge/models/221025MG_backbone.ssl.pth", map_location='cuda')

# Como guardaste solo encoder_q[0], necesitas asignarlo a esa parte del modelo
encoder.encoder_q[0].load_state_dict(state_dict)

encoder = encoder.cuda() 


In [4]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T

In [5]:
class MedMNISTUnifiedFolder(Dataset):
    def __init__(self, root, transform=None):
        self.root = root
        self.files = [os.path.join(root, f) for f in os.listdir(root)
                      if f.endswith(('.png', '.jpg', '.jpeg'))]
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img_path = self.files[idx]
        img = Image.open(img_path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img

In [6]:
transform = T.Compose([
    T.Resize((28,28)),
    T.ToTensor(),
])

dataset = MedMNISTUnifiedFolder("/lustre/home/atorres/compartido/datasets/all_medmnist_images", transform)
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=2)

In [7]:
def colorization_pair(img):
    gray = T.Grayscale()(img)
    return gray, img


In [19]:
# Jigsaw: simplificación, target = primer índice de permutación
def jigsaw_pair(imgs, n=3):
    B, C, H, W = imgs.shape
    shuffled_imgs = []
    orders = []
    for i in range(B):
        patches = []
        for y in range(n):
            for x in range(n):
                patch_h, patch_w = H // n, W // n
                patch = imgs[i, :, y*patch_h:(y+1)*patch_h, x*patch_w:(x+1)*patch_w]
                patches.append(patch)
        order = torch.randperm(len(patches))
        shuffled = torch.zeros_like(imgs[i])
        for k, idx in enumerate(order):
            yk, xk = divmod(k, n)
            shuffled[:, yk*patch_h:(yk+1)*patch_h, xk*patch_w:(xk+1)*patch_w] = patches[idx]
        shuffled_imgs.append(shuffled)
        orders.append(order[0])  # solo el primer índice para simplificar
    return torch.stack(shuffled_imgs), torch.tensor(orders)

In [17]:
from PIL import Image, ImageDraw
import torchvision.transforms.functional as TF

def patch_prediction_pair(imgs, mask_size=16):
    # imgs: tensor BxCxHxW
    masked_imgs = []
    target_imgs = []

    for img in imgs:
        # Convertir tensor a PIL
        pil_img = TF.to_pil_image(img.cpu())
        w, h = pil_img.size
        x = (w - mask_size) // 2
        y = (h - mask_size) // 2

        masked = pil_img.copy()
        draw = ImageDraw.Draw(masked)
        draw.rectangle([x, y, x+mask_size, y+mask_size], fill=(0,0,0))

        # Convertir de vuelta a tensor
        masked_tensor = TF.to_tensor(masked).to(img.device)
        target_tensor = img

        masked_imgs.append(masked_tensor)
        target_imgs.append(target_tensor)

    return torch.stack(masked_imgs), torch.stack(target_imgs)


In [10]:
import torch.nn as nn

class MultiPretextSSL(nn.Module):
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone  # tu encoder SSL

        # Cada tarea tiene su propia "head"
        self.color_head = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 4, 2, 1),
            nn.Sigmoid()
        )

        self.jigsaw_head = nn.Linear(512, 9 * 9)  # predicción del orden 9x9
        self.patch_head = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 4, 2, 1),
            nn.Sigmoid()
        )

    def forward(self, x, task="color"):
        feats = self.backbone(x).squeeze()
        if task == "color":
            return self.color_head(feats.unsqueeze(-1).unsqueeze(-1))
        elif task == "patch":
            return self.patch_head(feats.unsqueeze(-1).unsqueeze(-1))
        elif task == "jigsaw":
            return self.jigsaw_head(feats)


In [None]:
# --- Bucle de entrenamiento multitask ---
for epoch in range(10):
    for imgs in loader:
        imgs = imgs.cuda()
        task = random.choice(["color", "patch", "jigsaw"])

        if task == "color":
            gray, target = colorization_pair(imgs)
            gray, target = gray.cuda(), target.cuda()
            pred = encoder(gray, "color")
            loss = F.mse_loss(pred, target)
        
        elif task == "patch":
            masked, target = patch_prediction_pair(imgs)
            masked, target = masked.cuda(), target.cuda()
            pred = encoder(masked, "patch")
            loss = F.mse_loss(pred, target)
        
        else:  # jigsaw
            shuffled, target_order = jigsaw_pair(imgs)
            shuffled, target_order = shuffled.cuda(), target_order.cuda()
            pred = encoder(shuffled, "jigsaw")
            loss = F.cross_entropy(pred, target_order)

        opt.zero_grad()
        loss.backward()
        opt.step()

    print(f"Epoch {epoch+1} | Task: {task} | Loss: {loss.item():.4f}")




TypeError: forward() takes 2 positional arguments but 3 were given

: 