In [None]:
!pip install torchvision




In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision.transforms import ToTensor, Lambda


In [None]:
# Train on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
class MirroredStrategy:
    """Minimal stand-in for TensorFlow's ``tf.distribute.MirroredStrategy``.

    It only records the requested device names; nothing is replicated or
    synchronised across them. Using the object in a ``with`` statement is a
    no-op that yields the strategy itself, and ``scope()`` returns a
    ``torch.cuda.device`` context for the first listed device only.
    """

    def __init__(self, devices):
        self.devices = devices
        self.device_count = len(self.devices)

    def __enter__(self):
        # Nothing to set up; hand the strategy back to the ``with`` target.
        return self

    def __exit__(self, *exc_info):
        # Returning a falsy value lets any exception raised in the body propagate.
        return None

    def scope(self):
        """Return a CUDA device context manager for the first configured device."""
        primary = self.devices[0]
        return torch.cuda.device(primary)

In [None]:
# PyTorch replacement for the TensorFlow distribution setup. The stub strategy only
# records these device names; its scope() uses just the first one (cuda:4).
mirrored_strategy = MirroredStrategy(devices=[f"cuda:{gpu}" for gpu in range(4, 8)])

# Paired slice directories: T1 images are the model input (source),
# T1c images the ground truth the model should produce (target).
src_path = '/kaggle/input/brain-t1-and-t1c-scans/T1 Cropped 3D-T/T1_imgs_middle_only/'
tar_path = '/kaggle/input/brain-t1-and-t1c-scans/T1c Cropped 3D-T/T1c_imgs_middle_only/'
img_size = (145, 184)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DownsampleBlock(nn.Module):
    """Conv(4x4, stride 2, pad 1) -> optional InstanceNorm -> LeakyReLU(0.2).

    Halves an even spatial size (H, W) -> (H/2, W/2).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        apply_batchnorm: despite the name (kept for call compatibility) this toggles
            an ``InstanceNorm2d`` layer, not BatchNorm. It must be disabled where the
            output is 1x1, because instance norm cannot normalise a single element.
    """

    def __init__(self, in_channels, out_channels, apply_batchnorm=True):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(out_channels) if apply_batchnorm else nn.Identity(),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        return self.block(x)


class UpsampleBlock(nn.Module):
    """ConvTranspose(4x4, stride 2, pad 1) -> InstanceNorm -> ReLU [-> Dropout(0.5)].

    Doubles the spatial size (H, W) -> (2H, 2W).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the transposed convolution.
        apply_dropout: append ``Dropout(0.5)`` after the activation.
    """

    def __init__(self, in_channels, out_channels, apply_dropout=False):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        if apply_dropout:
            layers.append(nn.Dropout(0.5))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)


class Generator(nn.Module):
    """U-Net generator with skip connections and a tanh output in [-1, 1].

    The network has eight stride-2 downsampling blocks and seven upsampling blocks.
    Each upsampled map is concatenated with the matching encoder map, and a final
    transposed conv restores full resolution. The input H and W must be multiples
    of 256 (e.g. 256x256) so every skip connection lines up.

    Args:
        in_channels: channels of the input image (default 3).
        out_channels: channels of the generated image (default 3).
        ngf: base feature width. The default of 64 gives the 64/128/256/512 layout;
            smaller values give a lighter model with the same topology.
    """

    def __init__(self, in_channels=3, out_channels=3, ngf=64):
        super().__init__()
        c1, c2, c4, c8 = ngf, ngf * 2, ngf * 4, ngf * 8

        self.down_stack = nn.ModuleList([
            DownsampleBlock(in_channels, c1, False),
            DownsampleBlock(c1, c2),
            DownsampleBlock(c2, c4),
            DownsampleBlock(c4, c8),
            DownsampleBlock(c8, c8),
            DownsampleBlock(c8, c8),
            DownsampleBlock(c8, c8),
            # Bottleneck: the output is 1x1 for a 256x256 input, and InstanceNorm
            # raises on a single spatial element, so no norm here.
            DownsampleBlock(c8, c8, False),
        ])

        # After each up block the skip is concatenated, doubling the channel count
        # seen by the next block (hence the `* 2` inputs).
        self.up_stack = nn.ModuleList([
            UpsampleBlock(c8, c8, True),
            UpsampleBlock(c8 * 2, c8, True),
            UpsampleBlock(c8 * 2, c8, True),
            UpsampleBlock(c8 * 2, c8),
            UpsampleBlock(c8 * 2, c4),
            UpsampleBlock(c4 * 2, c2),
            UpsampleBlock(c2 * 2, c1),
        ])

        self.last = nn.ConvTranspose2d(c1 * 2, out_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        skips = []
        for down in self.down_stack:
            x = down(x)
            skips.append(x)

        # The bottleneck output is not a skip; pair the rest deepest-first with up_stack.
        for up, skip in zip(self.up_stack, reversed(skips[:-1])):
            x = torch.cat([up(x), skip], dim=1)

        return torch.tanh(self.last(x))

# Example usage: build one Generator (not moved to `device`) and print its layer tree.
gen = Generator()
print(repr(gen))


Generator(
  (down_stack): ModuleList(
    (0): DownsampleBlock(
      (block): Sequential(
        (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (1): Identity()
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (1): DownsampleBlock(
      (block): Sequential(
        (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (2): DownsampleBlock(
      (block): Sequential(
        (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (3): DownsampleBlock(
      (block): Sequential(
        (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1

In [None]:
# Fresh generator for training, placed on the selected device.
gen = Generator()
gen = gen.to(device)


In [None]:
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    """Conditional discriminator producing a spatial map of raw scores.

    An (input, candidate) image pair is concatenated along the channel axis
    (6 channels in total, matching ``down1``), then downsampled three times.
    A padded conv -> InstanceNorm -> LeakyReLU stage follows, and a final conv
    maps to one channel. No sigmoid is applied, so the outputs are logits.
    """

    def __init__(self):
        super().__init__()

        self.down1 = DownsampleBlock(6, 64, False)
        self.down2 = DownsampleBlock(64, 128)
        self.down3 = DownsampleBlock(128, 256)

        self.zero_pad1 = nn.ZeroPad2d((1, 0, 1, 0))
        # bias=False is only sensible with a normalisation layer after the conv;
        # without `self.norm` this layer would have neither a bias nor a norm.
        self.conv = nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=1, bias=False)
        # Same conv -> InstanceNorm -> LeakyReLU pattern as DownsampleBlock.
        # The layer has no parameters or buffers, so state_dict keys are unchanged.
        self.norm = nn.InstanceNorm2d(512)

        self.leaky_relu = nn.LeakyReLU(0.2, inplace=True)
        self.zero_pad2 = nn.ZeroPad2d((1, 0, 1, 0))

        self.last = nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)

    def forward(self, input, target=None):
        """Score an image pair.

        Args:
            input: conditioning image batch.
            target: candidate image batch (real or generated). If None, `input`
                is taken to already be the channel-wise concatenation of both.

        Returns:
            A single-channel map of logits.
        """
        x = input if target is None else torch.cat([input, target], dim=1)

        x = self.down1(x)
        x = self.down2(x)
        x = self.down3(x)

        x = self.zero_pad1(x)
        x = self.conv(x)
        x = self.norm(x)

        x = self.leaky_relu(x)
        x = self.zero_pad2(x)

        return self.last(x)

# Training discriminator. train_step, fit and the checkpoint cells refer to it as `disc`,
# while the optimizer cell uses `discriminator`. Both names are bound to the same module
# so the optimizer steps the network that is actually trained. It goes on `device` so it
# can consume the same batches as `gen`.
disc = Discriminator().to(device)
discriminator = disc
print(discriminator)


Discriminator(
  (down1): DownsampleBlock(
    (block): Sequential(
      (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): Identity()
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
  )
  (down2): DownsampleBlock(
    (block): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
  )
  (down3): DownsampleBlock(
    (block): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
  )
  (zero_pad1): ZeroPad2d((1, 0, 1, 0))
  (conv): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False)
  (leaky_relu): LeakyReLU(negative_slope=0.2, inplace

In [None]:
import torch
import torch.nn.functional as F
from torch.nn import L1Loss

def generator_loss(disc_generated_output, gen_output, target, lambda_value):
    """Generator objective: adversarial BCE term plus a weighted L1 reconstruction term.

    Args:
        disc_generated_output: discriminator logits for (input, generated) pairs.
        gen_output: generator output batch.
        target: ground-truth batch, same shape as `gen_output`.
        lambda_value: weight applied to the L1 term.

    Returns:
        Tuple ``(total, adversarial, l1)`` of scalar tensors.
    """
    # Adversarial term: the generator wants every score judged "real" (label 1).
    real_labels = torch.ones_like(disc_generated_output)
    adversarial = F.binary_cross_entropy_with_logits(disc_generated_output, real_labels)
    # Reconstruction term: mean absolute error against the ground truth.
    reconstruction = F.l1_loss(gen_output, target)
    return adversarial + lambda_value * reconstruction, adversarial, reconstruction



def discriminator_loss(disc_real_output, disc_generated_output):
    """Discriminator objective: BCE-with-logits on real (label 1) and generated (label 0) scores.

    Uses the same logit interpretation of the discriminator output as
    `generator_loss`. The previous L1 distance on raw outputs optimised a different
    objective from the one the generator was trained against.

    Args:
        disc_real_output: discriminator logits for (input, ground-truth) pairs.
        disc_generated_output: discriminator logits for (input, generated) pairs.

    Returns:
        Scalar tensor: real-pair loss + generated-pair loss.
    """
    real_loss = F.binary_cross_entropy_with_logits(
        disc_real_output, torch.ones_like(disc_real_output)
    )
    generated_loss = F.binary_cross_entropy_with_logits(
        disc_generated_output, torch.zeros_like(disc_generated_output)
    )
    return real_loss + generated_loss


# One Adam optimizer per network, both with lr=2e-4 and betas=(0.5, 0.999).
generator_optimizer, discriminator_optimizer = (
    optim.Adam(model.parameters(), lr=2e-4, betas=(0.5, 0.999))
    for model in (gen, discriminator)
)


In [None]:
def train_step(input_image, target, epoch):
    """Run one optimisation step for the discriminator and then the generator.

    The discriminator is updated first. It scores (input, target) against
    (input, detached generator output) with `discriminator_loss`. The generator
    is then updated with `generator_loss`, using a fresh discriminator pass on
    the non-detached output so the adversarial term back-propagates into `gen`.

    Args:
        input_image: source batch fed to the generator.
        target: ground-truth batch for `input_image`.
        epoch: current epoch index (unused; kept for call compatibility).

    Returns:
        ``(generator_total_loss, discriminator_loss)`` as Python floats.
    """
    gen.train()
    disc.train()

    input_image, target = input_image.to(device), target.to(device)

    gen_output = gen(input_image)

    # --- Discriminator update ---------------------------------------------
    # Detaching keeps this update from producing gradients for the generator.
    disc_real_output = disc(input_image, target)
    disc_fake_output = disc(input_image, gen_output.detach())
    disc_loss = discriminator_loss(disc_real_output, disc_fake_output)

    discriminator_optimizer.zero_grad()
    disc_loss.backward()
    discriminator_optimizer.step()

    # --- Generator update -------------------------------------------------
    # Re-score the *non-detached* output; the scores above were computed on a
    # detached copy and could not carry any gradient back to the generator.
    disc_generated_output = disc(input_image, gen_output)
    gen_total_loss, _, _ = generator_loss(disc_generated_output, gen_output, target, LAMBDA)

    generator_optimizer.zero_grad()
    gen_total_loss.backward()
    generator_optimizer.step()

    return gen_total_loss.item(), disc_loss.item()


In [None]:
def fit(train_loader, test_loader, epochs):
    """Train `gen` and `disc` for `epochs` epochs.

    Each epoch does four things:
    - writes sample outputs for every test batch via `save_images`;
    - runs `train_step` on every training batch;
    - appends the epoch's mean generator and discriminator losses to the
      module-level `gen_losses` / `disc_losses` lists;
    - saves a checkpoint every 5 epochs (epoch 0, 5, 10, ...).

    NOTE(review): `save_images`, `gen_losses` and `disc_losses` are not defined in
    any visible cell; they must exist before this is called.

    Raises:
        ValueError: if `train_loader` yields no batches in an epoch.
    """
    import os
    import time

    # Loop-invariant: create the output directory once; exist_ok makes it idempotent.
    os.makedirs('output', exist_ok=True)

    for epoch in range(epochs):
        start = time.time()

        for input_, target in test_loader:
            save_images(gen, input_, target, epoch)

        # Train
        print(f"Epoch {epoch}")
        gen_loss_total = 0.0
        disc_loss_total = 0.0
        num_batches = 0
        for input_, target in train_loader:
            # train_step returns plain Python floats (it already calls .item()).
            gen_loss, disc_loss = train_step(input_.to(device), target.to(device), epoch)
            gen_loss_total += gen_loss
            disc_loss_total += disc_loss
            num_batches += 1

        if num_batches == 0:
            raise ValueError("train_loader yielded no batches")

        # Record the epoch mean rather than the noisy last-batch value.
        gen_loss_value = gen_loss_total / num_batches
        disc_loss_value = disc_loss_total / num_batches

        gen_losses.append(gen_loss_value)
        disc_losses.append(disc_loss_value)

        if epoch % 5 == 0:
            torch.save({
                'generator_state_dict': gen.state_dict(),
                'discriminator_state_dict': disc.state_dict(),
                'gen_optimizer_state_dict': generator_optimizer.state_dict(),
                'disc_optimizer_state_dict': discriminator_optimizer.state_dict(),
            }, f'checkpoint_epoch_{epoch}.pt')

        print("Generator loss: {:.2f}".format(gen_loss_value))
        print("Discriminator loss: {:.2f}".format(disc_loss_value))
        print("Time taken for epoch {} is {} sec\n".format(epoch+1, time.time() - start))


In [None]:
def normalize(img):
    """Linearly map pixel intensities from [0, 255] to [-1, 1].

    Works on anything that supports division and subtraction by scalars
    (Python numbers, NumPy arrays, torch tensors). It assumes 0-255 input;
    values already scaled to [0, 1] would land near -1.
    """
    half_range = 127.5
    return img / half_range - 1

def resize(img, size=(256, 256)):
    """Nearest-neighbour resize of the last two (H, W) dimensions to `size`.

    Args:
        img: torch tensor or array-like (e.g. a NumPy array) laid out as (H, W),
            (C, H, W) or (N, C, H, W). Channel-last arrays such as those returned
            by ``plt.imread`` must be permuted to channel-first beforehand.
        size: target ``(height, width)``; defaults to 256x256.

    Returns:
        A tensor with the same leading dimensions and spatial size `size`.
        4-D input gives the same result as ``nn.Upsample(size, mode='nearest')``.

    Raises:
        ValueError: if the input is not 2-D, 3-D or 4-D.
    """
    import torch.nn.functional as F

    img = torch.as_tensor(img)
    if img.dim() not in (2, 3, 4):
        raise ValueError(f"resize expects a 2D, 3D or 4D input, got shape {tuple(img.shape)}")

    lead_shape = img.shape[:-2]
    # interpolate needs an explicit (N, C, H, W) layout for 2-D resizing. Nearest
    # resizing treats every channel independently, so folding all leading dims
    # into the batch axis does not change the result.
    batched = img.reshape(-1, 1, *img.shape[-2:])
    out = F.interpolate(batched, size=size, mode='nearest')
    return out.reshape(*lead_shape, *out.shape[-2:])

In [None]:
pip install natsort


Collecting natsort
  Obtaining dependency information for natsort from https://files.pythonhosted.org/packages/ef/82/7a9d0550484a62c6da82858ee9419f3dd1ccc9aa1c26a1e43da3ecd20b0d/natsort-8.4.0-py3-none-any.whl.metadata
  Downloading natsort-8.4.0-py3-none-any.whl.metadata (21 kB)
Downloading natsort-8.4.0-py3-none-any.whl (38 kB)
Installing collected packages: natsort
Successfully installed natsort-8.4.0
Note: you may need to restart the kernel to use updated packages.


In [None]:
from os import path, listdir


In [None]:
# Source (T1) and target (T1c) slices are paired by identical file name.
# `all_files` was never defined (the saved output is a NameError). Build it from the
# names present in both directories so every source slice has a matching target.
all_files = sorted(
    name for name in listdir(src_path) if path.isfile(path.join(tar_path, name))
)


def _load_slice(file_path):
    """Read one image as a (3, 256, 256) float tensor in [-1, 1] for the 3-channel Generator."""
    raw = plt.imread(file_path)
    img = torch.as_tensor(raw, dtype=torch.float32)
    # plt.imread returns PNGs as floats in [0, 1] but other formats as integers;
    # normalize() expects a 0-255 range.
    if raw.dtype.kind == 'f':
        img = img * 255.0
    if img.dim() == 2:  # grayscale (H, W) -> (H, W, 1)
        img = img.unsqueeze(-1)
    if img.shape[-1] < 3:
        # Gray (or gray + alpha): replicate the gray channel to 3 channels.
        img = img[..., :1].expand(-1, -1, 3)
    else:
        img = img[..., :3]  # drop an alpha channel if present
    img = img.permute(2, 0, 1)  # (H, W, C) -> (C, H, W)
    img = normalize(img)
    # resize is given a 4-D batch of one, then the batch axis is dropped again.
    return resize(img.unsqueeze(0)).squeeze(0)


src_list = [_load_slice(path.join(src_path, name)) for name in all_files]
tar_list = [_load_slice(path.join(tar_path, name)) for name in all_files]


NameError: name 'all_files' is not defined

In [None]:
# Collapse each list of per-image tensors into one stacked tensor on `device`.
src_list, tar_list = torch.stack(src_list).to(device), torch.stack(tar_list).to(device)

In [None]:
# Index-aligned (source, target) pairs, reshuffled on every pass over the loader.
dataset = TensorDataset(src_list, tar_list)
train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Persist only the generator's weights (no optimizer or discriminator state).
torch.save(obj=gen.state_dict(), f='generator55.pth')

In [None]:
# For each saved batch, show input / ground truth / prediction side by side; the
# prediction title carries the SSIM between ground truth and prediction.
# Assumes images are in [-1, 1] (tanh generator output / normalize()).
for idx in range(len(inputs)):
    plt.figure(figsize=(15, 15))
    display_list = [inputs[idx][0], targets[idx][0], predictions[idx][0]]

    # Both images are rescaled from [-1, 1] to [0, 1] first, so the SSIM data range is 1.
    target_np = (targets[idx][0].cpu().numpy() + 1) / 2.0
    prediction_np = (predictions[idx][0].cpu().numpy() + 1) / 2.0
    score, diff = structural_similarity(target_np, prediction_np, full=True, win_size=3, data_range=1.0)
    titles = ['Input Image', 'Ground Truth', f'Predicted Image: {score:.3f}']

    # A separate panel index avoids shadowing the outer loop variable.
    for panel, (title, image) in enumerate(zip(titles, display_list), start=1):
        plt.subplot(1, 3, panel)
        plt.title(title)
        # CHW -> HWC and [-1, 1] -> [0, 1]; clip guards imshow against rounding overshoot.
        plt.imshow((image.cpu().numpy().transpose(1, 2, 0) * 0.5 + 0.5).clip(0, 1))
        plt.axis('off')
    # Render and release each figure instead of accumulating open figures.
    plt.show()


In [None]:
# Resume from the epoch-10 checkpoint. `fit` writes checkpoints as
# f'checkpoint_epoch_{epoch}.pt', so the suffix must be '.pt'. map_location lets a
# checkpoint saved on a GPU be restored on whatever `device` is available now.
loaded_dict = torch.load('checkpoint_epoch_10.pt', map_location=device)
gen.load_state_dict(loaded_dict['generator_state_dict'])
disc.load_state_dict(loaded_dict['discriminator_state_dict'])
generator_optimizer.load_state_dict(loaded_dict['gen_optimizer_state_dict'])
discriminator_optimizer.load_state_dict(loaded_dict['disc_optimizer_state_dict'])


In [None]:
# Loss curves: `fit` appends one generator and one discriminator value per epoch.
fig, ax = plt.subplots()
ax.plot(gen_losses, label='gen_loss')
ax.plot(disc_losses, label='disc_loss')
ax.set(title='Training losses', xlabel='epoch', ylabel='loss')
ax.legend()
plt.show()
