# Cycle-GAN for CT To MRI Image Translation

## Overview
This project uses a generative adversarial network (GAN), referred to here as Cycle-GAN, to translate CT images into MRI scans. The model is trained on paired CT and MRI images, learning to generate realistic MRI scans from CT inputs. This approach could be useful in medical imaging by reducing the need for multiple scans and streamlining diagnostic workflows.

## Model Architecture
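Despite the Cycle-GAN name, the implemented model is a paired conditional GAN in the style of pix2pix: it uses a single generator and a single discriminator trained with adversarial and L1 losses, and it has no cycle-consistency loss. It consists of two primary components: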

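- **Generator:** A convolutional encoder-decoder (four downsampling and four upsampling layers, without U-Net skip connections) that takes a CT image as input and generates a corresponding MRI image.
- **Discriminator:** A PatchGAN-style model that receives an MRI image together with its CT input and predicts, for each image patch, whether the MRI is real or generated.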

The generator learns to fool the discriminator, while the discriminator improves its ability to distinguish real from fake images, leading to high-quality image translations.

## Dataset
The dataset consists of paired MRI and CT images:

- `trainA`: CT images
- `trainB`: Corresponding MRI images

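During training, the model takes a CT image from `trainA` and attempts to generate a realistic MRI image that matches the corresponding image from `trainB`. Images are paired by their position in the sorted file lists of the two folders, so corresponding CT and MRI files must sort into the same order (for example, by sharing file names).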

## Training Details

### Loss Functions
- **Adversarial Loss (BCE Loss):** Encourages the generator to produce realistic MRI images.
- **L1 Loss (weighted by 100):** Keeps the generated MRI close to the real MRI pixel by pixel, which helps preserve the overall anatomical structure.

### Optimization
- Adam optimizer with a learning rate of `0.0002` and betas `(0.5, 0.999)`.

### Preprocessing
- No random augmentation is applied; images are only resized to `256x256` and normalized to the range `[-1, 1]`.


## Usage
Trained models are saved as:

- `generator.pth`
- `discriminator.pth`

## Dependencies
Ensure the following dependencies are installed:

```bash
pip install torch torchvision pillow numpy pandas scikit-image matplotlib
```

## Conclusion
This GAN-based model translates CT images into synthetic MRI scans, which could reduce the need for multiple imaging modalities and streamline medical imaging workflows. Further improvements can be made by training for more epochs, training on larger datasets, or incorporating attention mechanisms for better detail preservation.

In [None]:
import numpy as np 
import pandas as pd 

import os
# Print the full path of every file under the Kaggle input mount; useful for
# locating the dataset folders referenced later in the notebook.
_input_root = '/kaggle/input'
for dirname, _, filenames in os.walk(_input_root):
    for filename in filenames:
        full_path = os.path.join(dirname, filename)
        print(full_path)


In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
import numpy as np


In [None]:

class cycleganDataset(Dataset):
    """Paired image dataset built from two folders.

    In this notebook ``trainA`` holds CT scans and ``trainB`` the corresponding
    MRI scans. Each folder's regular files are sorted by full path and paired by
    position: item ``i`` is (i-th file of A, i-th file of B). Correct pairing
    therefore relies on matching files sorting into the same order in both
    folders.

    Args:
        trainA_path: Directory with the source-domain images.
        trainB_path: Directory with the target-domain images.
        transform: Optional callable applied independently to each PIL image.
            The same deterministic transform is expected for both images; a
            random transform would be drawn separately for A and B.
    """

    def __init__(self, trainA_path, trainB_path, transform=None):
        self.trainA_images = self._list_files(trainA_path)
        self.trainB_images = self._list_files(trainB_path)
        self.transform = transform

    @staticmethod
    def _list_files(folder):
        """Return the sorted full paths of the regular files directly in *folder*."""
        # Sub-directories are skipped: Image.open would fail on them at load time.
        candidates = (os.path.join(folder, name) for name in os.listdir(folder))
        return sorted(path for path in candidates if os.path.isfile(path))

    @staticmethod
    def _load_rgb(path):
        """Open the image at *path*, convert it to RGB and close the file handle."""
        # convert() returns an independent, fully loaded image, so it stays
        # valid after the context manager closes the underlying file.
        with Image.open(path) as img:
            return img.convert('RGB')

    def __len__(self):
        # Unpaired leftovers in the longer folder are ignored.
        return min(len(self.trainA_images), len(self.trainB_images))

    def __getitem__(self, idx):
        """Return the ``(img_A, img_B)`` pair at position *idx*, transformed if configured."""
        img_A = self._load_rgb(self.trainA_images[idx])
        img_B = self._load_rgb(self.trainB_images[idx])
        if self.transform is not None:
            img_A = self.transform(img_A)
            img_B = self.transform(img_B)
        return img_A, img_B

In [None]:
class Generator(nn.Module):
    """Image-to-image generator: a 3-channel image in, a 3-channel image out.

    A plain convolutional encoder-decoder with no skip connections:

    * encoder: four ``Conv2d(kernel=4, stride=2, padding=1)`` layers
      (3 -> 64 -> 128 -> 256 -> 512 channels), each followed by
      ``LeakyReLU(0.2)``; every layer except the first also has ``BatchNorm2d``;
    * decoder: four ``ConvTranspose2d(kernel=4, stride=2, padding=1)`` layers
      (512 -> 256 -> 128 -> 64 -> 3 channels) followed by ``BatchNorm2d`` +
      ``ReLU``, except the last, which ends in ``Tanh`` so outputs lie in [-1, 1].

    Each encoder stage halves H and W and each decoder stage doubles them, so the
    output has the input's spatial size when H and W are multiples of 16. All
    layers live in one ``nn.Sequential`` called ``model``; their order (and hence
    the ``state_dict`` keys) is fixed.
    """

    def __init__(self):
        super().__init__()
        encoder_channels = [3, 64, 128, 256, 512]
        decoder_channels = [512, 256, 128, 64, 3]

        stack = []
        for stage, (c_in, c_out) in enumerate(zip(encoder_channels, encoder_channels[1:])):
            stack.append(nn.Conv2d(c_in, c_out, 4, 2, 1))
            if stage:  # the very first encoder layer is not normalised
                stack.append(nn.BatchNorm2d(c_out))
            stack.append(nn.LeakyReLU(0.2))

        final_stage = len(decoder_channels) - 2
        for stage, (c_in, c_out) in enumerate(zip(decoder_channels, decoder_channels[1:])):
            stack.append(nn.ConvTranspose2d(c_in, c_out, 4, 2, 1))
            if stage == final_stage:
                stack.append(nn.Tanh())
            else:
                stack.extend([nn.BatchNorm2d(c_out), nn.ReLU()])

        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Translate a batch of 3-channel images ``x`` into 3-channel images."""
        translated = self.model(x)
        return translated

In [None]:
class Discriminator(nn.Module):
    """Conditional PatchGAN-style discriminator.

    ``forward(x, y)`` concatenates its two inputs along the channel dimension
    (they must total 6 channels, e.g. two 3-channel images) and returns a grid of
    per-patch scores squashed into [0, 1] by a final ``Sigmoid``. The body is
    four ``Conv2d(kernel=4, padding=1)`` + ``LeakyReLU(0.2)`` stages (strides
    2, 2, 2, 1; ``BatchNorm2d`` on all but the first) and a last
    ``Conv2d(512 -> 1, kernel=4, stride=1, padding=1)``. For 256x256 inputs
    the output grid is 30x30.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, stride, use_batchnorm) for each hidden stage.
        stages = [
            (6, 64, 2, False),
            (64, 128, 2, True),
            (128, 256, 2, True),
            (256, 512, 1, True),
        ]
        layers = []
        for c_in, c_out, stride, use_bn in stages:
            layers.append(nn.Conv2d(c_in, c_out, 4, stride, 1))
            if use_bn:
                layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2))
        # Collapse to one score channel per spatial patch.
        layers.extend([nn.Conv2d(512, 1, 4, 1, 1), nn.Sigmoid()])
        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        """Score the channel-wise concatenation of ``x`` and ``y``."""
        paired = torch.cat((x, y), dim=1)
        return self.model(paired)

In [None]:

def weights_init(m):
    """DCGAN-style initialisation, intended for ``model.apply(weights_init)``.

    Layers are matched by class name:

    * names containing ``'Conv'`` (e.g. ``Conv2d``, ``ConvTranspose2d``):
      weights drawn from N(0, 0.02); the bias is left untouched;
    * names containing ``'BatchNorm'``: weights drawn from N(1, 0.02) and
      bias set to 0.

    Any other module is left unchanged. Layers whose ``weight``/``bias`` is
    ``None`` (e.g. ``BatchNorm2d(affine=False)``) are skipped instead of raising.
    """
    classname = type(m).__name__
    if 'Conv' in classname:
        if getattr(m, 'weight', None) is not None:
            nn.init.normal_(m.weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if getattr(m, 'weight', None) is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
        if getattr(m, 'bias', None) is not None:
            nn.init.constant_(m.bias, 0)


In [None]:
def save_generated_images(real_A, real_B, fake_B, epoch, save_dir):
    """Save a side-by-side PNG of the first sample of a batch.

    Writes ``<save_dir>/epoch_<epoch>.png`` with three panels: the input
    (``real_A``, titled "CT Scan"), the target (``real_B``, "Ground Truth") and
    the model output (``fake_B``, "Generated Image").

    Args:
        real_A, real_B, fake_B: Image batches as tensors indexed ``[sample,
            channel, H, W]``; only sample 0 is plotted. Values are assumed to be
            normalised to [-1, 1]; they are mapped back with ``x * 0.5 + 0.5``
            and clipped to [0, 1]. Tensors may live on any device and may still
            track gradients.
        epoch: Used only to name the output file.
        save_dir: Output directory; created if it does not exist.
    """
    os.makedirs(save_dir, exist_ok=True)

    panels = (
        ('CT Scan', real_A),
        ('Ground Truth', real_B),
        ('Generated Image', fake_B),
    )
    fig = plt.figure(figsize=(15, 5))
    try:
        for position, (title, batch) in enumerate(panels, start=1):
            # detach() lets tensors that still require grad be converted to numpy.
            image = batch[0].detach().cpu().numpy().transpose(1, 2, 0)
            # Undo the [-1, 1] normalisation; clipping avoids matplotlib's
            # out-of-range warnings for slightly overshooting values.
            image = np.clip(image * 0.5 + 0.5, 0.0, 1.0)
            ax = fig.add_subplot(1, 3, position)
            ax.set_title(title)
            if image.shape[-1] == 1:
                # imshow rejects (H, W, 1); show single-channel images in grey.
                ax.imshow(image[..., 0], cmap='gray', vmin=0.0, vmax=1.0)
            else:
                ax.imshow(image)
        fig.savefig(os.path.join(save_dir, f'epoch_{epoch}.png'))
    finally:
        # Close this specific figure even if plotting or saving fails.
        plt.close(fig)

In [None]:
# Training configuration. The loop below is a pix2pix-style objective: a
# conditional adversarial loss plus a weighted L1 term against the ground truth.
NUM_EPOCHS = 1        # passes over the dataset
LAMBDA_L1 = 100       # weight of the L1 reconstruction term in the generator loss
LOG_EVERY = 100       # print losses every N batches
SAMPLE_EVERY = 100    # save a preview every N epochs (the final epoch is always saved)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    # Maps each channel from [0, 1] to [-1, 1], the range of the generator's Tanh output.
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
dataset = cycleganDataset('/kaggle/input/ct-to-mri-cgan/Dataset/images/trainA', '/kaggle/input/ct-to-mri-cgan/Dataset/images/trainB', transform)
if len(dataset) == 0:
    # Fail early: an empty loader would otherwise surface later as a NameError
    # when the preview step reads the last batch.
    raise RuntimeError('No training pairs found; check the trainA/trainB dataset paths.')
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

generator = Generator().to(device)
discriminator = Discriminator().to(device)
generator.apply(weights_init)
discriminator.apply(weights_init)

criterion = nn.BCELoss()  # the discriminator already ends in Sigmoid
l1_loss = nn.L1Loss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

save_dir = '/kaggle/working/images'
os.makedirs(save_dir, exist_ok=True)

for epoch in range(NUM_EPOCHS):
    for i, (real_A, real_B) in enumerate(dataloader):
        real_A, real_B = real_A.to(device), real_B.to(device)
        fake_B = generator(real_A)

        # Discriminator step: (real MRI, CT) pairs -> 1, (generated MRI, CT) -> 0.
        optimizer_D.zero_grad()
        pred_real = discriminator(real_B, real_A)
        # detach() keeps the discriminator update from backpropagating into G.
        pred_fake = discriminator(fake_B.detach(), real_A)
        # Targets follow the discriminator's output shape (30x30 for 256x256
        # inputs) instead of hard-coding it.
        real_labels = torch.ones_like(pred_real)
        fake_labels = torch.zeros_like(pred_fake)
        real_loss = criterion(pred_real, real_labels)
        fake_loss = criterion(pred_fake, fake_labels)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # Generator step: fool the discriminator while staying close to the real MRI.
        optimizer_G.zero_grad()
        pred_fake_for_g = discriminator(fake_B, real_A)
        g_loss = criterion(pred_fake_for_g, torch.ones_like(pred_fake_for_g))
        l1 = l1_loss(fake_B, real_B)
        g_total_loss = g_loss + LAMBDA_L1 * l1
        g_total_loss.backward()
        optimizer_G.step()

        if i % LOG_EVERY == 0:
            print(f'Epoch [{epoch}/{NUM_EPOCHS}], Step [{i}/{len(dataloader)}], '
                  f'D_loss: {d_loss.item():.4f}, G_loss: {g_total_loss.item():.4f}')

    if epoch % SAMPLE_EVERY == 0 or epoch == NUM_EPOCHS - 1:
        # Preview on the last batch of the epoch. NOTE: the generator is still in
        # train mode, so BatchNorm uses this batch's statistics (and updates its
        # running stats with this extra pass).
        with torch.no_grad():
            fake_B = generator(real_A).cpu()
        real_A = real_A.cpu()
        real_B = real_B.cpu()
        save_generated_images(real_A, real_B, fake_B, epoch, save_dir)

# Checkpoints are written to the current working directory.
torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')
print("Models saved.")