In [None]:
%pip install pydicom numpy pillow torch --quiet

In [None]:
import os
import numpy as np
import pydicom
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [None]:
class QuickTestDataset(Dataset):
    """Tiny synthetic paired dataset for smoke-testing the training loop.

    Holds ``n_samples`` random "PET" images and ``n_samples`` independently drawn
    random "CT" images, each ``size x size`` with values in [0, 1). No DICOM files
    are read. Each item is a ``(pet, ct)`` pair of float64 tensors of shape
    ``(1, size, size)`` (a leading channel axis is added).

    Args:
        n_samples: number of image pairs (default 5).
        size: height and width of every image (default 256).
        seed: if given, draw from a private ``np.random.default_rng(seed)`` so the
            data is reproducible regardless of global state; if ``None`` (default)
            draw from NumPy's global RNG, so ``np.random.seed`` controls it.
    """

    def __init__(self, n_samples: int = 5, size: int = 256, seed=None):
        # Both the np.random module and a np.random.Generator expose
        # .random(shape) -> float64 samples in [0, 1).
        sampler = np.random if seed is None else np.random.default_rng(seed)
        self.pet = [sampler.random((size, size)) for _ in range(n_samples)]
        self.ct = [sampler.random((size, size)) for _ in range(n_samples)]

    def __len__(self):
        # Derive the length from the stored data instead of repeating a literal.
        return len(self.pet)

    def __getitem__(self, idx):
        pet = torch.tensor(self.pet[idx][None])  # (1, size, size)
        ct = torch.tensor(self.ct[idx][None])    # (1, size, size)
        return pet, ct

In [None]:
class MiniGenerator(nn.Module):
    """Tiny encoder/decoder mapping an image to a same-sized image.

    The encoder (3x3 conv + ReLU + 2x2 max-pool) halves the spatial size. The
    decoder (stride-2 4x4 transposed conv + Tanh) doubles it back. For even
    ``H`` and ``W``, an ``(N, in_channels, H, W)`` input therefore yields an
    ``(N, out_channels, H, W)`` output with values in ``[-1, 1]``.

    Args:
        in_channels: channels of the input image (default 1).
        out_channels: channels of the generated image (default 1).
        hidden_channels: feature maps produced by the encoder (default 32).
    """

    def __init__(self, in_channels: int = 1, out_channels: int = 1, hidden_channels: int = 32):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.decoder = nn.Sequential(
            # padding=1 makes the output exactly twice the input size:
            # (H/2 - 1) * 2 - 2 * 1 + 4 = H. With padding=0 a 256x256 input came
            # back as 258x258, and the L1 loss against a 256x256 target failed.
            nn.ConvTranspose2d(hidden_channels, out_channels, 4, stride=2, padding=1),
            # NOTE(review): Tanh outputs [-1, 1]; confirm this matches how the
            # target CT images are normalised.
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode then decode ``x``; the spatial size is preserved for even H, W."""
        return self.decoder(self.encoder(x))

In [None]:
# Reproducibility: seed every RNG this notebook draws from. The synthetic data
# uses NumPy's global RNG; weight initialisation and DataLoader shuffling draw
# from torch's RNG (torch.manual_seed seeds all devices).
SEED = 42
LEARNING_RATE = 1e-4
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
G = MiniGenerator().to(device)                     # PET -> CT generator
opt = torch.optim.Adam(G.parameters(), lr=LEARNING_RATE)
criterion = nn.L1Loss()                            # mean absolute pixel error

In [None]:
# 4. Fast Training Loop
# ==============================
def quick_train(epochs: int = 3, batch_size: int = 2) -> list:
    """Run a short smoke-test training loop for the global generator ``G``.

    Builds a fresh ``QuickTestDataset`` and trains ``G`` on it, using the global
    optimizer ``opt``, loss ``criterion`` and ``device``. It prints the mean
    training loss of every epoch.

    Args:
        epochs: number of passes over the dataset (default 3).
        batch_size: mini-batch size for the DataLoader (default 2).

    Returns:
        List with one float per epoch: the sample-weighted mean training loss
        (``nan`` if the loader yielded no samples).
    """
    dataset = QuickTestDataset()
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    G.train()
    epoch_losses = []
    for epoch in range(epochs):
        running_loss, n_seen = 0.0, 0
        for pet, ct in loader:
            # The dataset yields float64 tensors; cast them to float32 to match the
            # model's default parameter dtype.
            pet, ct = pet.float().to(device), ct.float().to(device)
            opt.zero_grad()
            fake_ct = G(pet)
            loss = criterion(fake_ct, ct)
            loss.backward()
            opt.step()
            # Weight each batch by its size: the last batch may be smaller.
            running_loss += loss.item() * pet.size(0)
            n_seen += pet.size(0)
        # Report the epoch mean, not just the final batch's loss.
        mean_loss = running_loss / n_seen if n_seen else float("nan")
        epoch_losses.append(mean_loss)
        print(f"Epoch {epoch+1} Loss: {mean_loss:.4f}")
    return epoch_losses

In [None]:
# 5. Run & Verify
# ==============================
if __name__ == "__main__":
    # matplotlib is not among the notebook's top-level imports; import it here so
    # this cell runs on a fresh kernel instead of raising NameError on `plt`.
    import matplotlib.pyplot as plt

    quick_train()
    print("\n✅ Quick test successful! Now run full notebook.")

    # Sample output visualization. The input is Gaussian noise standing in for a
    # PET slice, not real scan data.
    G.eval()
    test_input = torch.randn(1, 1, 256, 256).to(device)
    with torch.no_grad():
        output = G(test_input).cpu().numpy()

    fig, (ax_in, ax_out) = plt.subplots(1, 2, figsize=(10, 5))
    ax_in.imshow(test_input[0, 0].cpu().numpy(), cmap='gray')
    ax_in.set_title('Input PET')
    ax_out.imshow(output[0, 0], cmap='gray')
    ax_out.set_title('Generated CT')
    plt.show()