In [None]:
!mkdir -p ~/.kaggle
!cp /kaggle/input/apikaggle/kaggle.json ~/.kaggle/

!kaggle competitions download -c rsna-pneumonia-detection-challenge

## Unzip Dataset

In [None]:
#unzip the file
import zipfile
zip_ref = zipfile.ZipFile('/kaggle/working/rsna-pneumonia-detection-challenge.zip', 'r')
zip_ref.extractall('/kaggle/working/rsna-pneumonia-detection-dataset')
zip_ref.close()

In [1]:
!pip install pydicom wandb torchsummary

Collecting torchsummary
  Downloading torchsummary-1.5.1-py3-none-any.whl.metadata (296 bytes)
Downloading torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1


In [2]:
pip install --upgrade torchmetrics lightning==2.2.3

Collecting torchmetrics
  Downloading torchmetrics-1.4.0-py3-none-any.whl.metadata (19 kB)
Collecting lightning==2.2.3
  Downloading lightning-2.2.3-py3-none-any.whl.metadata (53 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.4/53.4 kB[0m [31m918.3 kB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Collecting pretty-errors==1.2.25 (from torchmetrics)
  Downloading pretty_errors-1.2.25-py3-none-any.whl.metadata (12 kB)
Downloading lightning-2.2.3-py3-none-any.whl (2.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m12.4 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading torchmetrics-1.4.0-py3-none-any.whl (868 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m868.8/868.8 kB[0m [31m45.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pretty_errors-1.2.25-py3-none-any.whl (17 kB)
Installing collected packages: pretty-errors, torchmetrics, lightning
  Attempting uninstall: torchmetrics
    Found ex

In [3]:
import os
import shutil
import pandas as pd
import pydicom
from PIL import Image
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models

In [4]:
import wandb
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("WANDB_KEY")

wandb.login(key=secret_value_0)


[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [None]:
import os
import shutil
import pandas as pd
import pydicom
from PIL import Image
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor

class ImageConverter:
    def __init__(self, input_folder, output_folder, csv_path, image_size=512):
        self.input_folder = input_folder
        self.output_folder = output_folder
        self.csv_path = csv_path
        self.image_size = image_size

    def _convert_single_dcm_to_png(self, file_name, label):
        # Check if the DICOM file exists
        dicom_path = os.path.join(self.input_folder, file_name + '.dcm',)
        if not os.path.exists(dicom_path):
            print(f"Warning: DICOM file not found for {dicom_path}")
            return

        # Read DICOM file
        dicom_data = pydicom.dcmread(dicom_path)

        # Convert DICOM to PNG
        image_array = dicom_data.pixel_array
        image = Image.fromarray(image_array)
        image = image.resize((self.image_size, self.image_size))

        # Define the output path based on train/test/val and label
        if self.index % 5 == 0:  # 20% for validation
            output_path = os.path.join(self.output_folder, f'val/{label}/{file_name[:-4]}.png')
        elif self.index % 5 == 1:  # 20% for test
            output_path = os.path.join(self.output_folder, f'test/{label}/{file_name[:-4]}.png')
        else:  # 60% for train
            output_path = os.path.join(self.output_folder, f'train/{label}/{file_name[:-4]}.png')

        # Save the image
        image.save(output_path)

    def _convert_dcm_to_png_for_index(self, index_row):
        self.index, row = index_row
        file_name = row['patientId']
        label = row['Target']
        self._convert_single_dcm_to_png(file_name, label)

    def convert_dcm_to_png_parallel(self):
        # Create output folders if they don't exist
        for folder in ['train/0/', 'train/1/', 'test/0/', 'test/1/', 'val/0/', 'val/1/']:
            os.makedirs(os.path.join(self.output_folder, folder), exist_ok=True)

        # Read CSV file
        df = pd.read_csv(self.csv_path)
        print('Total files: ', df.shape[0])

        # Use ProcessPoolExecutor for parallel processing
        with ProcessPoolExecutor() as executor:
            list(tqdm(executor.map(self._convert_dcm_to_png_for_index, df.iterrows()), total=len(df), desc="Converting images"))
# if __name__ == "__main__":
input_folder = "/kaggle/working/rsna-pneumonia-detection-dataset/stage_2_train_images"
output_folder = "/kaggle/working/CycleGan-CFE/train-data"
csv_path = "/kaggle/working/rsna-pneumonia-detection-dataset/stage_2_train_labels.csv"

image_converter = ImageConverter(input_folder, output_folder, csv_path)
image_converter.convert_dcm_to_png_parallel()


In [5]:
class ClassifierDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        self.classes = ['0', '1']
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        self.samples = self._make_dataset()

    def _make_dataset(self):
        samples = []
        for class_name in self.classes:
            class_dir = os.path.join(self.root_dir, class_name)
            for img_name in os.listdir(class_dir):
                img_path = os.path.join(class_dir, img_name)
                samples.append((img_path, self.class_to_idx[class_name]))
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        img = Image.open(img_path).convert('L')  # Convert to grayscale
        if self.transform:
            img = self.transform(img)
        return img, label

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import lightning as pl
import wandb
from lightning.pytorch.loggers.wandb import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.tuner import Tuner
import tqdm.auto as tqdm
from torchmetrics import Accuracy

class Classifier(pl.LightningModule):
    def __init__(self, transfer=True):
        super(Classifier, self).__init__()
        self.conv = nn.Conv2d(1, 3, kernel_size=3, stride=1, padding=1)  # Adjust input channels to 3
        self.model = models.swin_t(weights='IMAGENET1K_V1')
        if transfer:
            # layers are frozen by using eval()
            self.model.eval()
            # freeze params
            for p in self.model.parameters() : 
                p.requires_grad = False
        num_ftrs = 768
        self.model.head = nn.Sequential(
            nn.Linear(in_features=num_ftrs, out_features=256),
            nn.LeakyReLU(),
            nn.Dropout(p=0.5), 
            nn.Linear(in_features=256 , out_features=2),
            nn.Softmax(dim=1)  
        ) 

        self.criterion = nn.CrossEntropyLoss()
        self.train_accuracy = Accuracy(task='binary')
        self.val_accuracy = Accuracy(task='binary')

    def forward(self, x):
        x = self.conv(x)
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        self.log('train_loss', loss)
        # Calculate and log accuracy
        _, preds = torch.max(outputs, 1)
        acc = self.train_accuracy(preds, labels)
        self.log('train_acc', acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        self.log('val_loss', loss, prog_bar=True, sync_dist=True)
        # Calculate and log accuracy
        _, preds = torch.max(outputs, 1)
        acc = self.val_accuracy(preds, labels)
        self.log('val_acc', acc, prog_bar=True, sync_dist=True)
        return loss
    
    def on_train_epoch_end(self):
        self.train_accuracy.reset()

    def on_validation_epoch_end(self):
        self.val_accuracy.reset()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.0001)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=5, verbose=True)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_loss',
            },
            'monitor': 'val_loss'
        }

wandb_logger = WandbLogger(project="CycleGAN-CFE", name="swin_t-classifier-training")
# Define data transformations
IMAGE_SIZE = 512
BATCH_SIZE = 16
EPOCHS = 20

transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),  # Resize image to 512x512
    transforms.ToTensor(),          
    transforms.Normalize(mean=[0.485], std=[0.229])  # Normalize image
])

# Define dataset paths
train_dir = "/kaggle/working/CycleGan-CFE/train-data/train"
val_dir = "/kaggle/working/CycleGan-CFE/train-data/val"

# Create datasets
train_dataset = ClassifierDataset(root_dir=train_dir, transform=transform)
val_dataset = ClassifierDataset(root_dir=val_dir, transform=transform)
print("Total Training Images: ",len(train_dataset))
print("Total Validation Images: ",len(val_dataset))

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True, num_workers=4)

Total Training Images:  16739
Total Validation Images:  6046


In [None]:
# Instantiate the discriminator model
clf = Classifier(transfer=True)

checkpoint_callback = ModelCheckpoint(
     monitor='val_loss',
     dirpath='/kaggle/working/CycleGan-CFE/models/',
     filename='swin_t-epoch{epoch:02d}-val_loss{val_loss:.2f}',
     auto_insert_metric_name=False,
 )
# Set up PyTorch Lightning Trainer with multiple GPUs and tqdm progress bar
trainer = pl.Trainer(
    devices=2,
    accelerator="gpu",
    max_epochs=EPOCHS,
    accumulate_grad_batches=10,
    log_every_n_steps=1,
    check_val_every_n_epoch=1,
    benchmark=True,
    logger=wandb_logger,
    callbacks=[checkpoint_callback],
)


# Train the discriminator
trainer.fit(clf, train_loader, val_loader)
wandb.finish()
# if __name__ == "__main__":
#     main()

In [7]:
image_size = 512

In [8]:
import os
import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__(self, root_dir, train_N, train_P, img_res=(128, 128)):
        self.root_dir = root_dir
        self.train_N = train_N
        self.train_P = train_P
        self.img_res = img_res
        self.transforms = transforms.Compose([
            transforms.Resize(img_res),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])  # Assuming grayscale images
        ])

    def __len__(self):
        return min(len(os.listdir(os.path.join(self.root_dir, self.train_N))),
                   len(os.listdir(os.path.join(self.root_dir, self.train_P))))

    def __getitem__(self, idx):
        normal_path = os.path.join(self.root_dir, self.train_N, os.listdir(os.path.join(self.root_dir, self.train_N))[idx])
        pneumo_path = os.path.join(self.root_dir, self.train_P, os.listdir(os.path.join(self.root_dir, self.train_P))[idx])
        
        normal_img = Image.open(normal_path).convert("L")  # Load as grayscale
        pneumo_img = Image.open(pneumo_path).convert("L")  # Load as grayscale
        
        normal_img = self.transforms(normal_img)
        pneumo_img = self.transforms(pneumo_img)
        
        return normal_img, pneumo_img




In [9]:
import torch
import torch.nn as nn
import lightning as pl
import wandb

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, dropout_rate=0.0):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual
        out = self.relu(out)
        return out

class ResUNetGenerator(pl.LightningModule):
    def __init__(self, gf, channels, dropout_rate=0.3):
        super(ResUNetGenerator, self).__init__()
        self.gf = gf
        self.channels = channels
        self.dropout_rate = dropout_rate

        # Define the layers for the encoder
        self.conv2d = nn.Conv2d(channels, gf, kernel_size=4, stride=2, padding=1, padding_mode='reflect')
        
        self.conv2d_layers_left = nn.ModuleList([
            nn.Conv2d(gf * 2**i, gf * 2**(i+1), kernel_size=4, stride=2, padding=1, padding_mode='reflect')
            for i in range(4)
        ])
        
        self.conv2d_layers_right = nn.ModuleList([
            nn.Conv2d(gf * 2**(i+1), gf * 2**i, kernel_size=3, stride=1, padding=1, padding_mode='reflect')
            for i in range(4)
        ])
        
        self.groupNorm_layers = nn.ModuleList([
            nn.GroupNorm(8, gf * 2**(i+1))
            for i in range(4)
        ])
        
        self.res_blocks_left = nn.ModuleList([
            ResidualBlock(gf * 2**(i+1), gf * 2**(i+1), dropout_rate)
            for i in range(4)
        ])
        
        self.res_blocks_right = nn.ModuleList([
            ResidualBlock(gf * 2**i, gf * 2**i, dropout_rate)
            for i in range(4)
        ])

        # Define the layers for the decoder
        self.deconv2d_layers = nn.ModuleList([
            nn.ConvTranspose2d(gf * 2**(4-i), gf * 2**(3-i), kernel_size=4, stride=2, padding=1)
            for i in range(4)
        ])
        self.deconv2d_final = nn.ConvTranspose2d(gf, channels, kernel_size=4, stride=2, padding=1)
        self.leaky_relu = nn.LeakyReLU(0.2)
        self.group_norm = nn.GroupNorm(8, gf)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        
        d0 = self.leaky_relu(self.group_norm(self.conv2d(x)))
        d1 = self.leaky_relu(self.groupNorm_layers[0](self.conv2d_layers_left[0](d0)))
        d1 = self.res_blocks_left[0](d1)
        
        d2 = self.leaky_relu(self.groupNorm_layers[1](self.conv2d_layers_left[1](d1)))
        d2 = self.res_blocks_left[1](d2)
        
        d3 = self.leaky_relu(self.groupNorm_layers[2](self.conv2d_layers_left[2](d2)))
        d3 = self.res_blocks_left[2](d3)
        
        d4 = self.leaky_relu(self.groupNorm_layers[3](self.conv2d_layers_left[3](d3)))
        d4 = self.res_blocks_left[3](d4)


        # Decoder
        u1 = self.deconv2d_layers[0](d4)
        u1 = torch.cat((u1, d3), dim=1)
        u1 = self.leaky_relu(self.groupNorm_layers[2](self.conv2d_layers_right[3](u1)))
        u1 = self.res_blocks_right[3](u1)
        
        u2 = self.deconv2d_layers[1](u1)
        u2 = torch.cat((u2, d2), dim=1)
        u2 = self.leaky_relu(self.groupNorm_layers[1](self.conv2d_layers_right[2](u2)))
        u2 = self.res_blocks_right[2](u2)

        u3 = self.deconv2d_layers[2](u2)
        u3 = torch.cat((u3, d1), dim=1)
        u3 = self.leaky_relu(self.groupNorm_layers[0](self.conv2d_layers_right[1](u3)))
        u3 = self.res_blocks_right[1](u3)
        
        u4 = self.deconv2d_layers[3](u3)
        u4 = torch.cat((u4, d0), dim=1)
        u4 = self.leaky_relu(self.group_norm(self.conv2d_layers_right[0](u4)))
        u4 = self.res_blocks_right[0](u4)

        output_img = self.sig(self.deconv2d_final(u4))

        return output_img

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.0002, betas=(0.5, 0.999))
        return optimizer

In [10]:
class Discriminator(pl.LightningModule):
    def __init__(self, df):
        super(Discriminator, self).__init__()
        self.df = df
        # Define the layers for the discriminator
        self.conv_layers = nn.ModuleList([nn.Sequential(
            nn.Conv2d(1 if i == 0 else df * 2**(i-1), df * 2**i, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.GroupNorm(8, df * 2**i)) for i in range(4)])
        
        self.final_conv = nn.Conv2d(df * 8, 1, kernel_size=4, stride=1, padding=1)

    def forward(self, x):
        out = x
        for conv_layer in self.conv_layers:
            out = conv_layer(out)
        validity = self.final_conv(out)
        return validity

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.0002, betas=(0.5, 0.999))
        return optimizer


In [None]:
from torchsummary import summary
model = Discriminator(df=64)
# Summarize the model architecture
summary(model, input_size=(1, 512, 512))

In [None]:
from torchsummary import summary
model = ResUNetGenerator(gf=32, channels=1)
# Summarize the model architecture
summary(model, input_size=(1, 512, 512))

In [11]:
batch_size = 4

In [13]:
import numpy as np
import matplotlib.pyplot as plt

class CycleGAN(pl.LightningModule):
    def __init__(self, img_shape=(1, 512, 512), gf=32, df=64, lambda_cycle=10.0, lambda_id=0.1, classifier_path=None, classifier_weight=None):
        super(CycleGAN, self).__init__()
        self.img_shape = img_shape
        self.gf = gf
        self.df = df
        self.lambda_cycle = lambda_cycle
        self.lambda_id = lambda_id * lambda_cycle
        self.classifier_path = classifier_path
        self.classifier_weight = classifier_weight

        # Initialize the generator, discriminator, and classifier models
        self.g_NP = ResUNetGenerator(gf, channels=self.img_shape[0])
        self.g_PN = ResUNetGenerator(gf, channels=self.img_shape[0])
        self.d_N = Discriminator(df)
        self.d_P = Discriminator(df)
        self.automatic_optimization = False
        
        self.classifier = Classifier()
        checkpoint = torch.load(classifier_path)
        self.classifier.load_state_dict(checkpoint['state_dict'])
        self.classifier.eval()
        self.freeze_classifier()
    
    def freeze_classifier(self):
        print("freezing Classifier...")
        for p in self.classifier.parameters() : 
                p.requires_grad = False


    def generator_training_step(self, img_N, img_P, opt):
        self.toggle_optimizer(opt)
        # Translate images to the other domain
        fake_P = self.g_NP(img_N)
        fake_N = self.g_PN(img_P)

        # Translate images back to original domain
        reconstr_N = self.g_PN(fake_P)
        reconstr_P = self.g_NP(fake_N)

        # Identity mapping of images
        img_N_id = self.g_PN(img_N)
        img_P_id = self.g_NP(img_P)
        # Discriminators determine validity of translated images
        valid_N = self.d_N(fake_N)
        valid_P = self.d_P(fake_P)

        class_N_loss = self.classifier(fake_N)
        class_P_loss = self.classifier(fake_P)
        # Adversarial loss
        valid_target = torch.ones_like(valid_N)
        adversarial_loss = nn.MSELoss()(valid_N, valid_target) + nn.MSELoss()(valid_P, valid_target)

        # Cycle consistency loss
        cycle_loss = nn.L1Loss()(reconstr_N, img_N) + nn.L1Loss()(reconstr_P, img_P)

        # Identity loss
        identity_loss = nn.L1Loss()(img_N_id, img_N) + nn.L1Loss()(img_P_id, img_P)

        # Classifier loss
        class_loss = nn.MSELoss()(class_N_loss, torch.ones_like(class_N_loss)) + nn.MSELoss()(class_P_loss, torch.zeros_like(class_P_loss))

        # Total generator loss
        total_loss = adversarial_loss + self.lambda_cycle * cycle_loss + self.lambda_id * identity_loss + self.classifier_weight * class_loss
              
        self.log('adversarial_loss', adversarial_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('cycle_loss', cycle_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('identity_loss', identity_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('class_loss', class_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('generator_loss', total_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        
        opt.zero_grad()
        self.manual_backward(total_loss)
        opt.step()
        self.untoggle_optimizer(opt)
        
        return total_loss, adversarial_loss, cycle_loss

    def discriminator_training_step(self, img_N, img_P, opt):
        # Pass real images through discriminator D_N
        self.toggle_optimizer(opt)
        pred_real_N = self.d_N(img_N)
        # Compute MSE loss for real Negative images
        mse_real_N = nn.MSELoss()(pred_real_N, torch.ones_like(pred_real_N))

        # Pass fake images from positive to discriminator D_N
        fake_P = self.g_PN(img_P)
        pred_fake_N = self.d_N(fake_P)
        # Compute MSE loss for fake images in domain P
        mse_fake_N = nn.MSELoss()(pred_fake_N, torch.zeros_like(pred_fake_N))
        # Pass real images through discriminator D_P
        pred_real_P = self.d_P(img_P)
        # Compute MSE loss for real images in domain P
        mse_real_P = nn.MSELoss()(pred_real_P, torch.ones_like(pred_real_P))

        # Pass fake images from domain N to discriminator D_P
        fake_N = self.g_NP(img_N)  # Detach to prevent backpropagation to generator
        pred_fake_P = self.d_P(fake_N)
        # Compute MSE loss for fake images in domain N
        mse_fake_P = nn.MSELoss()(pred_fake_P, torch.zeros_like(pred_fake_P))
        

        # Compute total discriminator loss
        dis_loss = 0.5 * (mse_real_N + mse_fake_N + mse_real_P + mse_fake_P)
        opt.zero_grad()
        self.manual_backward(mse_fake_P)
        opt.step()
        self.untoggle_optimizer(opt)
        
        self.log('mse_fake_N', mse_fake_N, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('mse_fake_P', mse_fake_P, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('discriminator_loss', dis_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        return dis_loss, mse_fake_N, mse_fake_P
    
    def training_step(self, batch, batch_idx):
        img_N, img_P = batch
        optD, optG = self.optimizers()
        
        total_loss, adversarial_loss, cycle_loss = self.generator_training_step(img_N, img_P, optG)
        dis_loss, mse_fake_N, mse_fake_P = self.discriminator_training_step(img_N, img_P, optD) 
        
        return {"generator_loss": total_loss, "adversarial_loss": adversarial_loss, "cycle_loss": cycle_loss, "discriminator_loss": dis_loss, "mse_fake_N": mse_fake_N, "mse_fake_P": mse_fake_P}
    
    def validation_step(self, batch, batch_idx):
        img_N, img_P = batch

        # Translate images to the other domain
        fake_P = self.g_NP(img_N)
        fake_N = self.g_PN(img_P)

        # Translate images back to original domain
        reconstr_N = self.g_PN(fake_P)
        reconstr_P = self.g_NP(fake_N)

        # Identity mapping of images
        img_N_id = self.g_PN(img_N)
        img_P_id = self.g_NP(img_P)

        # Discriminators determine validity of translated images
        valid_N = self.d_N(fake_N)
        valid_P = self.d_P(fake_P)

        class_N_loss = self.classifier(fake_N)
        class_P_loss = self.classifier(fake_P)

        # Adversarial loss
        valid_target = torch.ones_like(valid_N)
        adversarial_loss = nn.MSELoss()(valid_N, valid_target) + nn.MSELoss()(valid_P, valid_target)

        # Cycle consistency loss
        cycle_loss = nn.L1Loss()(reconstr_N, img_N) + nn.L1Loss()(reconstr_P, img_P)

        # Identity loss
        identity_loss = nn.L1Loss()(img_N_id, img_N) + nn.L1Loss()(img_P_id, img_P)

        # Classifier loss
        class_loss = nn.MSELoss()(class_N_loss, torch.ones_like(class_N_loss)) + nn.MSELoss()(class_P_loss, torch.zeros_like(class_P_loss))

        # Total generator loss
        total_loss = adversarial_loss + self.lambda_cycle * cycle_loss + self.lambda_id * identity_loss + self.classifier_weight * class_loss

        self.log('val_adversarial_loss', adversarial_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_cycle_loss', cycle_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_identity_loss', identity_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_class_loss', class_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_generator_loss', total_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)

        return total_loss

    def configure_optimizers(self):
        optG = torch.optim.Adam(itertools.chain(self.g_NP.parameters(), self.g_PN.parameters()),lr=2e-4, betas=(0.5, 0.999))
        optD = torch.optim.Adam(itertools.chain(self.d_N.parameters(), self.d_P.parameters()),lr=2e-4, betas=(0.5, 0.999))
        
        gamma = lambda epoch: 1 - max(0, epoch + 1 - 100) / 101
        schD = LambdaLR(optD, lr_lambda=gamma)
#         Optimizer= [optD, optG]
        return optD, optG

    def train_dataloader(self):
        root_dir = "/kaggle/working/CycleGan-CFE/train-data/train"
        train_N = "0"
        train_P = "1"
        img_res = (image_size, image_size)

        dataset = CustomDataset(root_dir=root_dir, train_N=train_N, train_P=train_P, img_res=img_res)

        # Set up DataLoader for parallel processing and GPU acceleration
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)

        return dataloader
    
    def val_dataloader(self):
        root_dir = "/kaggle/working/CycleGan-CFE/train-data/val"
        train_N = "0"
        train_P = "1"
        img_res = (image_size, image_size)

        dataset = CustomDataset(root_dir=root_dir, train_N=train_N, train_P=train_P, img_res=img_res)

        # Set up DataLoader for parallel processing and GPU acceleration
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

        return dataloader
     

    def on_train_batch_end(self, outputs, batch, batch_idx):
        if batch_idx % 100 == 0:
            # Get a random batch from the test dataloader
            batch = next(iter(test_dataloader))
            img_N, img_P = batch

            # Pick a random image from the batch
            idx = np.random.randint(img_N.size(0))
            img_N = img_N[idx].unsqueeze(0).to('cuda')
            img_P = img_P[idx].unsqueeze(0).to('cuda')
            # Translate images to the other domain
            fake_P = self.g_NP(img_N)
            fake_N = self.g_PN(img_P)

            # Translate images back to original domain
            reconstr_N = self.g_PN(fake_P)
            reconstr_P = self.g_NP(fake_N)

            # Plot the images
            fig, axes = plt.subplots(2, 3, figsize=(15, 10))

            # Plot real N, translated P, and reconstructed N
            axes[0, 0].imshow(img_N.squeeze(0).permute(1, 2, 0).cpu().detach().numpy(), cmap='gray')
            axes[0, 0].set_title("Real N")
            axes[0, 0].axis('off')

            axes[0, 1].imshow(fake_P.squeeze(0).permute(1, 2, 0).cpu().detach().numpy(), cmap='gray')
            axes[0, 1].set_title("Translated P")
            axes[0, 1].axis('off')

            axes[0, 2].imshow(reconstr_N.squeeze(0).permute(1, 2, 0).cpu().detach().numpy(), cmap='gray')
            axes[0, 2].set_title("Reconstructed N")
            axes[0, 2].axis('off')

            # Plot real P, translated N, and reconstructed P
            axes[1, 0].imshow(img_P.squeeze(0).permute(1, 2, 0).cpu().detach().numpy(), cmap='gray')
            axes[1, 0].set_title("Real P")
            axes[1, 0].axis('off')

            axes[1, 1].imshow(fake_N.squeeze(0).permute(1, 2, 0).cpu().detach().numpy(), cmap='gray')
            axes[1, 1].set_title("Translated N")
            axes[1, 1].axis('off')

            axes[1, 2].imshow(reconstr_P.squeeze(0).permute(1, 2, 0).cpu().detach().numpy(), cmap='gray')
            axes[1, 2].set_title("Reconstructed P")
            axes[1, 2].axis('off')

            # Log the figure in WandB
            wandb.log({"test_images": wandb.Image(fig)})

            plt.close(fig)

In [15]:
from torch.optim.lr_scheduler import LambdaLR
import itertools
from lightning.pytorch.callbacks import ModelCheckpoint

cyclegan = CycleGAN(gf=32, df=64, classifier_path='/kaggle/input/swin-tiny/pytorch/v0/1/swin_t-epoch00-val_loss0.35.ckpt', classifier_weight=1)

checkpoint_callback = ModelCheckpoint(dirpath="/kaggle/working/CycleGan-CFE/models",
                                      filename='cyclegan-epoch_{epoch}-vloss_{val_generator_loss:.2f}.ckpt',
                                      monitor='val_generator_loss',
                                      save_top_k=3,
                                      save_last=True,
                                      save_weights_only=True,
                                      verbose=True,
                                      mode='min')

testdata_dir = "/kaggle/working/CycleGan-CFE/train-data/val"
train_N = "0"
train_P = "1"
img_res = (image_size, image_size)

test_dataset = CustomDataset(root_dir=testdata_dir, train_N=train_N, train_P=train_P, img_res=img_res)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

wandb_logger = WandbLogger(project="CycleGAN-CFE", name="GAN-training",log_model="all")
# Create the trainer
trainer = pl.Trainer(
    accelerator="auto",
    max_epochs=2,
    log_every_n_steps=1,
    benchmark=True,
    devices="auto",
    logger=wandb_logger,
    callbacks= [checkpoint_callback]
)

# Train the CycleGAN model
trainer.fit(cyclegan)


freezing Classifier.


INFO: Trainer will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=2)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs


/opt/conda/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:653: Checkpoint directory /kaggle/working/CycleGan-CFE/models exists and is not empty.
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
INFO: 
  | Name       | Type             | Params
------------------------------------------------
0 | g_NP       | ResUNetGenerator | 15.0 M
1 | g_PN       | ResUNetGenerator | 15.0 M
2 | d_N        | Discriminator    | 2.8 M 
3 | d_P        | Discriminator    | 2.8 M 
4 | classifier | Classifier       | 27.7 M
------------------------------------------------
35.5 M    Trainable params
27.7 M    Non-trainable params
63.2 M    Total params
252.837   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [None]:
pl.__version__

In [14]:
import torch
import torch.nn as nn

class AttentionGate(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(AttentionGate, self).__init__()
        self.conv_gate = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.conv_x = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x, g):
        gate = self.conv_gate(g)
        x = self.conv_x(x)
        attention = self.softmax(gate)
        x_att = x * attention
        return x_att

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.conv(x)
        x_pool = self.pool(x)
        return x, x_pool

class Up(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Up, self).__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.attention = AttentionGate(in_channels // 2, out_channels)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        x2 = self.attention(x2, x1)
        x = torch.cat([x1, x2], dim=1)
        x = self.conv(x)
        return x

class AttentionUNetGenerator(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super(AttentionUNetGenerator, self).__init__()
        self.down1 = Down(in_channels, 64)
        self.down2 = Down(64, 128)
        self.down3 = Down(128, 256)
        self.down4 = Down(256, 512)
        self.up1 = Up(512, 256)
        self.up2 = Up(512, 128)
        self.up3 = Up(256, 64)
        self.up4 = nn.Conv2d(128, out_channels, kernel_size=1, stride=1)

    def forward(self, x):
        x1, x = self.down1(x)
        x2, x = self.down2(x)
        x3, x = self.down3(x)
        x4, x = self.down4(x)
        x = self.up1(x4, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)
        x = self.up4(x)
        return x
    
# Example usage
generator = AttentionGate(in_channels=512, out_channels=512)
input_tensor = torch.randn(1, 512, 16, 16)
g = torch.randn(1, 512, 16, 16)
output_tensor = generator(input_tensor, g)
print(output_tensor.shape)  # Output: torch.Size([1, 1, 512, 512])

torch.Size([1, 512, 16, 16])


In [40]:
import torch
import torch.nn as nn

class AttentionGate(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(AttentionGate, self).__init__()
        self.conv_gate = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.conv_x = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x, g):
        gate = self.conv_gate(g)
        x = self.conv_x(x)
        attention = self.softmax(gate)
        x_att = x * attention
        return x_att
    
class UNetGenerator(nn.Module):
    def __init__(self, img_shape, gf, channels):
        super(UNetGenerator, self).__init__()
        self.img_shape = img_shape
        self.channels = channels
        
        # Downsampling layers
        self.conv1 = nn.Sequential(
            nn.Conv2d(channels, gf, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.GroupNorm(num_groups=1, num_channels=gf)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(gf, gf * 2, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.GroupNorm(num_groups=1, num_channels=gf * 2)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(gf * 2, gf * 4, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.GroupNorm(num_groups=1, num_channels=gf * 4)
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(gf * 4, gf * 8, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.GroupNorm(num_groups=1, num_channels=gf * 8)
        )

        self.attn_layer = nn.ModuleList([
            AttentionGate(gf * 2**(i), gf * 2**(i+1))
            for i in range(3)
        ])

        # Upsampling layers
        self.deconv1 = nn.Sequential(
            nn.ConvTranspose2d(gf * 8, gf * 4, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.GroupNorm(num_groups=1, num_channels=gf * 4)
        )
        self.deconv2 = nn.Sequential(
            nn.ConvTranspose2d(gf * 8, gf * 2, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.GroupNorm(num_groups=1, num_channels=gf * 2)
        )
        self.deconv3 = nn.Sequential(
            nn.ConvTranspose2d(gf * 4, gf, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.GroupNorm(num_groups=1, num_channels=gf)
        )
        self.deconv4 = nn.Sequential(
            nn.ConvTranspose2d(gf * 2, channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )
        
    def forward(self, x):
        # Downsampling
        d1 = self.conv1(x)
        d2 = self.conv2(d1)
        d3 = self.conv3(d2)
        d4 = self.conv4(d3)
        
        # Upsampling
        u1 = self.deconv1(d4)
        u1 = self.attn_layer[2](d3, u1)
        
        u2 = self.deconv2(u1)
        u2 = self.attn_layer[1](d2, u2)
        
        u3 = self.deconv3(u2)
        u3 = self.attn_layer[0](d1, u3)
        
        output = self.deconv4(u3)
        
        return output

model = UNetGenerator(img_shape=(1, 512, 512), gf=32, channels=1)
summary(model=model, input_size=(1, 512, 512))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 256, 256]             544
         LeakyReLU-2         [-1, 32, 256, 256]               0
         GroupNorm-3         [-1, 32, 256, 256]              64
            Conv2d-4         [-1, 64, 128, 128]          32,832
         LeakyReLU-5         [-1, 64, 128, 128]               0
         GroupNorm-6         [-1, 64, 128, 128]             128
            Conv2d-7          [-1, 128, 64, 64]         131,200
         LeakyReLU-8          [-1, 128, 64, 64]               0
         GroupNorm-9          [-1, 128, 64, 64]             256
           Conv2d-10          [-1, 256, 32, 32]         524,544
        LeakyReLU-11          [-1, 256, 32, 32]               0
        GroupNorm-12          [-1, 256, 32, 32]             512
  ConvTranspose2d-13          [-1, 128, 64, 64]         524,416
             ReLU-14          [-1, 128,

In [60]:
import torch
import torch.nn as nn
import lightning as pl
import wandb

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, dropout_rate=0.0):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)


    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)
        out += residual
        out = self.relu(out)
        return out
    
class AttentionGate(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(AttentionGate, self).__init__()
        self.conv_gate = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.conv_x = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x, g):
        gate = self.conv_gate(g)
        x = self.conv_x(x)
        attention = self.softmax(gate)
        x_att = x * attention
        return x_att
    

class ResUNetGenerator(pl.LightningModule):
    def __init__(self, gf, channels, dropout_rate=0.3):
        super(ResUNetGenerator, self).__init__()
        self.gf = gf
        self.channels = channels
        self.dropout_rate = dropout_rate

        # Define the layers for the encoder
        self.conv2d = nn.Conv2d(channels, gf, kernel_size=4, stride=2, padding=1, padding_mode='reflect')
        
        self.conv2d_layers_left = nn.ModuleList([
            nn.Conv2d(gf * 2**i, gf * 2**(i+1), kernel_size=4, stride=2, padding=1, padding_mode='reflect')
            for i in range(3)
        ])
        
        self.attn_layer = nn.ModuleList([
            AttentionGate(gf * 2**(i), gf * 2**(i))
            for i in range(3)
        ])
        
        self.groupNorm_layers = nn.ModuleList([
            nn.GroupNorm(8, gf * 2**(i+1))
            for i in range(3)
        ])
        
        self.res_blocks_left = nn.ModuleList([
            ResidualBlock(gf * 2**(i+1), gf * 2**(i+1), dropout_rate)
            for i in range(3)
        ])
        


        # Define the layers for the decoder
        self.deconv2d_layers = nn.ModuleList([
            nn.ConvTranspose2d(gf * 2**(3-i), gf * 2**(2-i), kernel_size=4, stride=2, padding=1)
            for i in range(3)
        ])

        self.deconv2d_final = nn.ConvTranspose2d(gf, channels, kernel_size=4, stride=2, padding=1)
        self.leaky_relu = nn.LeakyReLU(0.2)
        self.group_norm = nn.GroupNorm(8, gf)
        self.sig = nn.Tanh()

    def forward(self, x):
        
        d0 = self.leaky_relu(self.group_norm(self.conv2d(x)))
        d1 = self.leaky_relu(self.groupNorm_layers[0](self.conv2d_layers_left[0](d0)))
        d1 = self.res_blocks_left[0](d1)
        
        d2 = self.leaky_relu(self.groupNorm_layers[1](self.conv2d_layers_left[1](d1)))
        d2 = self.res_blocks_left[1](d2)
        
        d3 = self.leaky_relu(self.groupNorm_layers[2](self.conv2d_layers_left[2](d2)))
        d3 = self.res_blocks_left[2](d3)
    
        
        u1 = self.deconv2d_layers[0](d3)
        u1 = self.attn_layer[2](u1, d2)

        u2 = self.deconv2d_layers[1](u1)
        u2 = self.attn_layer[1](u2, d1)
        
        u3 = self.deconv2d_layers[2](u2)
        u3 = self.attn_layer[0](u3, d0)

        output_img = self.sig(self.deconv2d_final(u3))

        return output_img

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.0002, betas=(0.5, 0.999))
        return optimizer

In [61]:
from torchsummary import summary

In [62]:
model = ResUNetGenerator(gf=32, channels=1)
summary(model=model, input_size=(1, 512, 512))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 256, 256]             544
         GroupNorm-2         [-1, 32, 256, 256]              64
         LeakyReLU-3         [-1, 32, 256, 256]               0
            Conv2d-4         [-1, 64, 128, 128]          32,832
         GroupNorm-5         [-1, 64, 128, 128]             128
         LeakyReLU-6         [-1, 64, 128, 128]               0
            Conv2d-7         [-1, 64, 128, 128]          36,864
              ReLU-8         [-1, 64, 128, 128]               0
            Conv2d-9         [-1, 64, 128, 128]          36,864
             ReLU-10         [-1, 64, 128, 128]               0
    ResidualBlock-11         [-1, 64, 128, 128]               0
           Conv2d-12          [-1, 128, 64, 64]         131,200
        GroupNorm-13          [-1, 128, 64, 64]             256
        LeakyReLU-14          [-1, 128,