In [3]:
import h5py
import numpy as np
import pandas as pd
import cv2
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torchvision.transforms as T

In [4]:
# Open the challenge file explicitly read-only: the notebook only reads from it,
# and older h5py releases fall back to a writable mode when none is given.
# The handle is deliberately kept open because later cells read from `f`.
f = h5py.File("./data/elucidata_ai_challenge_data.h5", "r")

train_images = f['images/Train'] # 6 images with id's like S_1,S_2 to S_6
test_image = f['images/Test'] # 1 image with id S_7

train_spots = f['spots/Train'] # this contains spot data for each 6 ids first 2 columns are x and y and column 3 to 37 are cell abundance
test_spot = f['spots/Test']

In [8]:
# Print the HDF5 dataset descriptor (name, shape, dtype) of every training slide.
for key in train_images:
    print(train_images[key])

<HDF5 dataset "S_1": shape (2000, 1974, 3), type "<f4">
<HDF5 dataset "S_2": shape (2000, 1988, 3), type "<f4">
<HDF5 dataset "S_3": shape (2000, 1966, 3), type "<f4">
<HDF5 dataset "S_4": shape (2000, 1979, 3), type "<f4">
<HDF5 dataset "S_5": shape (1985, 2000, 3), type "<f4">
<HDF5 dataset "S_6": shape (2000, 1930, 3), type "<f4">


In [26]:
def _crop_patch(img, x, y, patch_size):
    """Return the ``patch_size`` x ``patch_size`` window centred on (x, y).

    Returns ``None`` when the window would extend past any image border.
    The window starts ``patch_size // 2`` pixels before the centre and spans
    exactly ``patch_size`` pixels, so odd sizes work as well as even ones.
    """
    half = patch_size // 2
    top, left = y - half, x - half
    bottom, right = top + patch_size, left + patch_size
    if top < 0 or left < 0 or bottom > img.shape[0] or right > img.shape[1]:
        return None
    return img[top:bottom, left:right]


def extract_patches(h5_file, patch_size=128, val_ratio=0.1, seed=None):
    """Cut square image patches around every training spot and split them train/val.

    For each slide under ``images/Train`` the matching table under
    ``spots/Train`` is read; its first two columns are used as integer (x, y)
    pixel coordinates and the remaining columns as regression targets.
    Spots are shuffled per slide and the last ``val_ratio`` fraction goes to the
    validation set. A spot whose patch would not fit entirely inside the image
    is discarded.

    Args:
        h5_file: Mapping-like object (e.g. an open ``h5py.File``) exposing
            ``'images/Train'`` and ``'spots/Train'``.
        patch_size: Side length of the square patch in pixels.
        val_ratio: Fraction of each slide's spots assigned to validation.
        seed: Optional seed for a private random generator, making the split
            reproducible. ``None`` (default) shuffles with NumPy's global
            random state, as before.

    Returns:
        Tuple ``(train_patches, train_targets, val_patches, val_targets)`` of
        NumPy arrays.
    """
    images = h5_file['images/Train']
    spots = h5_file['spots/Train']

    # One generator shared by all slides so a single seed fixes the whole split.
    shuffle = np.random.shuffle if seed is None else np.random.default_rng(seed).shuffle

    train_patches, train_targets = [], []
    val_patches, val_targets = [], []

    for slide_id in images.keys():
        img = np.array(images[slide_id])  # shape: (H, W, 3)

        # Spot tables may be structured arrays; DataFrame handles both layouts.
        spot_data = pd.DataFrame(spots[slide_id][()])

        coords = spot_data.iloc[:, :2].astype(int).values  # x, y
        targets = spot_data.iloc[:, 2:].values  # one row of targets per spot

        indices = np.arange(coords.shape[0])
        shuffle(indices)
        split_idx = int((1 - val_ratio) * len(indices))

        # Crop a square patch centred on each spot; spots too close to the
        # border to yield a full square are dropped.
        for subset, patches_out, targets_out in (
            (indices[:split_idx], train_patches, train_targets),
            (indices[split_idx:], val_patches, val_targets),
        ):
            for idx in subset:
                x, y = coords[idx]
                patch = _crop_patch(img, x, y, patch_size)
                if patch is not None:
                    patches_out.append(patch)
                    targets_out.append(targets[idx])

    return (
        np.array(train_patches), np.array(train_targets),
        np.array(val_patches), np.array(val_targets)
    )

In [27]:
# Build the train/validation patch arrays from the open HDF5 file, then report their shapes.
train_patches, train_targets, val_patches, val_targets = extract_patches(f)

for label, array in (
    ("Train patches:", train_patches),
    ("Train targets:", train_targets),
    ("Val patches:", val_patches),
    ("Val targets:", val_targets),
):
    print(label, array.shape)

Train patches: (7513, 128, 128, 3)
Train targets: (7513, 35)
Val patches: (836, 128, 128, 3)
Val targets: (836, 35)


In [29]:
import torch
from torch.utils.data import Dataset
from torchvision import transforms

class SpotPatchDataset(Dataset):
    """Dataset pairing image patches with their per-spot target vectors.

    Item ``idx`` is ``(image, target)``: the patch passed through
    ``transforms.ToTensor()`` and the target converted to a float32 tensor.
    """

    def __init__(self, patches, targets):
        self.patches = patches
        self.targets = targets

        # ToTensor reorders HWC -> CHW. NOTE(review): per the torchvision docs it
        # rescales to [0, 1] only for uint8 input; float arrays are presumably
        # passed through with their original value range - confirm for the
        # installed version.
        self.transform = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        return len(self.patches)

    def __getitem__(self, idx):
        image = self.transform(self.patches[idx])
        target = torch.tensor(self.targets[idx], dtype=torch.float32)
        return image, target


In [30]:
from torch.utils.data import DataLoader

# Wrap the extracted arrays as datasets.
train_dataset = SpotPatchDataset(train_patches, train_targets)
val_dataset = SpotPatchDataset(val_patches, val_targets)

# Only the training loader shuffles; validation keeps a fixed order.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=0,
)

In [31]:
import torch.nn as nn
from torchvision import models

class ResNetRegressor(nn.Module):
    """Pretrained ResNet-18 whose final fully connected layer is swapped for a linear regression head.

    Args:
        output_dim: Number of values predicted per input image.
    """

    def __init__(self, output_dim=35):
        super().__init__()
        backbone = models.resnet18(pretrained=True)

        # Swap the original classifier for a regression head of matching input width.
        backbone.fc = nn.Linear(backbone.fc.in_features, output_dim)
        self.backbone = backbone

    def forward(self, x):
        return self.backbone(x)

In [32]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNetRegressor(output_dim=35).to(device)

# Regression objective: mean squared error over the 35 outputs.
criterion = nn.MSELoss()

# Adam over every model parameter.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)



Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to C:\Users\PilockHulmes/.cache\torch\hub\checkpoints\resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:16<00:00, 2.77MB/s]


In [33]:
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader`` and return the mean per-sample loss.

    Args:
        model: Module to train; switched to training mode.
        dataloader: Iterable yielding ``(inputs, targets)`` tensor batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable; assumed to return the batch mean (as the
            default ``nn.MSELoss()`` does), which is re-weighted by batch size.
        device: Device the batches are moved to.

    Returns:
        Sum of per-batch losses weighted by batch size, divided by the number
        of samples actually iterated.

    Raises:
        ZeroDivisionError: If the dataloader yields no samples.
    """
    model.train()
    running_loss = 0.0
    samples_seen = 0

    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        batch_size = inputs.size(0)
        running_loss += loss.item() * batch_size
        samples_seen += batch_size

    # Divide by the samples processed rather than len(dataloader.dataset):
    # the two differ when the loader skips samples (drop_last=True, subset samplers).
    return running_loss / samples_seen

def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without gradients and return the mean per-sample loss.

    Args:
        model: Module to evaluate; switched to evaluation mode.
        dataloader: Iterable yielding ``(inputs, targets)`` tensor batches.
        criterion: Loss callable; assumed to return the batch mean (as the
            default ``nn.MSELoss()`` does), which is re-weighted by batch size.
        device: Device the batches are moved to.

    Returns:
        Sum of per-batch losses weighted by batch size, divided by the number
        of samples actually iterated.

    Raises:
        ZeroDivisionError: If the dataloader yields no samples.
    """
    model.eval()
    running_loss = 0.0
    samples_seen = 0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            batch_size = inputs.size(0)
            running_loss += loss.item() * batch_size
            samples_seen += batch_size

    # Divide by the samples processed rather than len(dataloader.dataset):
    # the two differ when the loader skips samples (drop_last=True, subset samplers).
    return running_loss / samples_seen

In [34]:
num_epochs = 5

# Alternate one training pass and one validation pass per epoch, logging both losses.
for epoch in range(num_epochs):
    train_loss, val_loss = (
        train_one_epoch(model, train_loader, optimizer, criterion, device),
        validate(model, val_loader, criterion, device),
    )

    print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

Epoch 1/5 | Train Loss: 0.7452 | Val Loss: 0.4794
Epoch 2/5 | Train Loss: 0.4144 | Val Loss: 0.4325
Epoch 3/5 | Train Loss: 0.2760 | Val Loss: 0.3907
Epoch 4/5 | Train Loss: 0.2018 | Val Loss: 0.4086
Epoch 5/5 | Train Loss: 0.1717 | Val Loss: 0.4143


In [35]:
def extract_test_patches(h5_file, patch_size=128, slide_id='S_7'):
    """Cut a square patch around every spot of one test slide.

    The first two columns of the slide's spot table are used as integer (x, y)
    pixel coordinates. A spot whose patch would not fit entirely inside the
    image is skipped, so the returned rows can be fewer than the slide's spots;
    a warning reports how many were skipped because downstream row numbers
    then no longer match the original spot order.

    Args:
        h5_file: Mapping-like object (e.g. an open ``h5py.File``) exposing
            ``'images/Test'`` and ``'spots/Test'``.
        patch_size: Side length of the square patch in pixels (odd sizes are
            supported; the window starts ``patch_size // 2`` before the centre).
        slide_id: Key of the test slide to read. Defaults to ``'S_7'``.

    Returns:
        Tuple ``(patches, valid_coords)`` of NumPy arrays holding the kept
        patches and the (x, y) coordinates they were cut around.
    """
    import warnings

    test_img = np.array(h5_file['images/Test'][slide_id])
    test_spot_data = pd.DataFrame(np.array(h5_file['spots/Test'][slide_id]))

    coords = test_spot_data.iloc[:, :2].astype(int).values  # x, y
    half = patch_size // 2
    height, width = test_img.shape[:2]

    patches = []
    valid_coords = []

    for coord in coords:
        x, y = coord
        left, top = x - half, y - half
        right, bottom = left + patch_size, top + patch_size
        if left < 0 or top < 0 or right > width or bottom > height:
            continue
        patches.append(test_img[top:bottom, left:right])
        valid_coords.append(coord)

    dropped = len(coords) - len(patches)
    if dropped:
        warnings.warn(
            f"{dropped} of {len(coords)} test spots were skipped because their "
            f"{patch_size}px patch crosses the image border; returned rows no "
            "longer line up one-to-one with the original spot order."
        )

    return np.array(patches), np.array(valid_coords)

In [36]:
test_patches, test_coords = extract_test_patches(f)

# Reorder channels-last patches (N, H, W, C) to channels-first (N, C, H, W).
# With the default patch_size of 128 the spatial size is 128 x 128.
test_patches_tensor = (
    torch.tensor(test_patches)
    .permute(0, 3, 1, 2)
    .float()
)

In [37]:
# Inference in evaluation mode, 32 patches at a time, with gradients disabled.
model.eval()

with torch.no_grad():
    predictions = [
        model(test_patches_tensor[start:start + 32].to(device)).cpu().numpy()
        for start in range(0, len(test_patches_tensor), 32)
    ]

# Stack the per-batch outputs into one (num_valid_spots, 35) array.
predictions = np.vstack(predictions)

In [38]:
# One column per predicted value, named cell_type_1 .. cell_type_35.
df_preds = pd.DataFrame(
    predictions,
    columns=[f'cell_type_{i+1}' for i in range(35)],
)
# NOTE(review): IDs simply number the rows of `predictions`; if any test spot was
# skipped during patch extraction they will not match the original spot index - confirm.
df_preds.insert(loc=0, column='ID', value=range(len(df_preds)))

In [39]:
# Display the prediction table as the cell output (the rendering below is truncated).
df_preds

Unnamed: 0,ID,cell_type_1,cell_type_2,cell_type_3,cell_type_4,cell_type_5,cell_type_6,cell_type_7,cell_type_8,cell_type_9,...,cell_type_26,cell_type_27,cell_type_28,cell_type_29,cell_type_30,cell_type_31,cell_type_32,cell_type_33,cell_type_34,cell_type_35
0,0,-0.174584,0.027349,-0.012746,0.082535,-0.046270,0.104565,0.085526,-0.055208,0.126084,...,0.006996,0.075627,-0.070324,0.072997,0.129147,0.021696,-0.024244,0.125135,-0.027928,0.000338
1,1,0.131282,0.161235,0.207157,-0.139339,0.741986,-0.073061,-0.050273,0.047330,0.052717,...,0.030605,-0.124524,-0.100443,0.130916,0.011940,0.093736,-0.025264,0.086049,0.061811,-0.011804
2,2,2.964469,0.155687,1.844579,1.257519,1.531961,0.041686,0.243872,-0.035342,0.029778,...,0.035841,0.066327,0.093249,0.089618,-0.089410,0.166974,0.200569,0.024088,0.011976,0.164817
3,3,2.477433,0.157950,1.775822,1.067452,1.162134,0.098915,0.172305,-0.066085,0.079294,...,0.044754,0.083666,0.097626,0.146031,-0.099052,0.214242,0.154818,0.085896,0.079986,0.194094
4,4,0.746700,0.100708,0.503298,0.623538,0.811936,0.159962,0.149131,0.107832,-0.126790,...,0.047254,-0.021375,-0.203817,0.077706,-0.109591,0.190179,0.001684,0.030271,-0.083986,0.104075
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2083,2083,1.314880,0.171256,0.614562,0.702651,2.039878,0.156576,0.113571,0.032434,0.294636,...,0.113106,0.161888,0.025886,0.246883,-0.021899,0.200136,0.184720,0.185973,-0.032186,0.168197
2084,2084,0.577924,-0.025309,0.464563,0.488105,0.778466,0.204367,0.066787,-0.063337,-0.016045,...,0.042081,0.089921,-0.107632,-0.018493,-0.064729,0.079831,-0.102010,0.129238,-0.030160,-0.030514
2085,2085,1.046885,0.065520,0.538273,0.617251,0.911059,0.003176,0.082960,0.132783,0.087148,...,-0.104870,0.137216,-0.195490,-0.108503,-0.005482,0.063600,-0.064646,-0.006508,0.110449,0.065237
2086,2086,0.144632,0.162647,0.167482,0.027493,0.422153,0.129378,0.059011,0.061215,-0.125368,...,0.031321,-0.196345,0.007433,0.137486,-0.016016,0.048250,-0.001363,0.216284,0.101883,-0.070533


In [40]:
# Write the predictions to submission.csv, omitting the DataFrame index column.
df_preds.to_csv('submission.csv', index=False)