In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from torch.utils.data import Dataset
from torchvision.models import resnet18, ResNet18_Weights
import matplotlib.pyplot as plt
import numpy as np
import cv2
import os
import warnings
import pandas as pd
warnings.filterwarnings("ignore")
import re
from tqdm import tqdm
from PIL import Image

df = pd.read_csv('kaggle_3m\data.csv')
missing = df[df['age_at_initial_pathologic'].isna()]['Patient']
missing

  df = pd.read_csv('kaggle_3m\data.csv')


109    TCGA_HT_A61B
Name: Patient, dtype: object

In [2]:
def get_images(dir):
    images = []
    ids = []

    for subdir in os.listdir(dir):
        path = os.path.join(dir, subdir)
        pattern = r"TCGA_(CS|DU|FG|HT|EZ)_(\w{4})"
        match = re.search(pattern, path)

        if match:
            ids.append(match.group(2))

        if path.startswith("kaggle_3m\TCGA"):
            for image_name in os.listdir(path):
                
                if "mask" in image_name:
                    continue
                elif "TCGA_HT_A61B" in image_name:
                    continue
                else:
                    images.append(os.path.join(path, image_name))
    return images, ids

def get_labels(images, ids):
    ids = [id for image in images for id in ids if  id in image]
    df['num_id'] = df['Patient'].str.extract(r'([^_]+)$')
    labels = []
    for id in ids:
        age = df.loc[df['num_id'] == id, 'age_at_initial_pathologic'].values
        labels.append(age)
    return labels

class CustomImageDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
            
        if self.transform:
            image = self.transform(image)
        
        return image, label

def evaluate(model, loader, device):
    model.eval()
    total_error = 0
    count = 0
    with torch.no_grad():  
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images) 
            total_error += torch.sum(torch.abs(outputs - labels))
            count += images.size(0)
    
    return total_error / count


In [None]:
images, ids = get_images('kaggle_3m')
labels = get_labels(images, ids)
for i in range(len(images)):
    images[i] = cv2.imread(images[i])
    images[i] = np.array(images[i])

transform = transforms.Compose([
    transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0)),  
    transforms.ToTensor(),  
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 
])

dataset = CustomImageDataset(images, labels, transform=transform)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

size = len(loader.dataset)
train_size = int(0.7 * size)
val_size = int(0.15 * size)
test_size = size - train_size - val_size

train_set, val_set, test_set = random_split(loader.dataset, [train_size, val_size, test_size])
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
val_loader = DataLoader(val_set, batch_size=16, shuffle=False)
test_loader = DataLoader(test_set, batch_size=16, shuffle=False)

weights = ResNet18_Weights.DEFAULT
model = resnet18(weights=weights)
model.fc = nn.Linear(model.fc.in_features, 1)

device = torch.device("cpu")
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
epochs = 10


for epoch in range(epochs):
    model.train() 
    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()  
        outputs = model(images)  
        loss = criterion(outputs.float(), labels.float())

        loss.backward()  
        optimizer.step()  

    MAE = evaluate(model, val_loader, device)
    print(f"Epoch {epoch+1}, Mean Absolute Error: {MAE.item()}")

Epoch 1:  66%|██████▌   | 111/168 [03:59<02:28,  2.60s/it]

In [None]:
torch.save(model.state_dict(), "resnet.pth")

In [22]:
MAE = evaluate(model, test_loader, device)
print(f"Mean Absolute Error: {MAE.item()}")

tensor([[51.],
        [66.],
        [59.],
        [70.],
        [49.],
        [66.],
        [40.],
        [33.],
        [36.],
        [47.],
        [61.],
        [33.],
        [49.],
        [51.],
        [41.],
        [30.]], dtype=torch.float64)
tensor([[51.4234],
        [69.2569],
        [58.6057],
        [70.9119],
        [50.0224],
        [67.7363],
        [40.9583],
        [40.9038],
        [36.0853],
        [42.9509],
        [64.5266],
        [41.1832],
        [46.4910],
        [50.3956],
        [42.7205],
        [31.5594]])
tensor([[48.],
        [30.],
        [43.],
        [58.],
        [70.],
        [39.],
        [28.],
        [49.],
        [34.],
        [29.],
        [34.],
        [48.],
        [29.],
        [32.],
        [57.],
        [30.]], dtype=torch.float64)
tensor([[58.1993],
        [32.7801],
        [41.6715],
        [59.3292],
        [67.7300],
        [48.4527],
        [25.9687],
        [50.3198],
        [34.0857],
