In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision
import os
import matplotlib.pyplot as plt
import numpy as np
import cv2
from tqdm import tqdm
# KMP_DUPLICATE_LIB_OK=TRUE tells the Intel OpenMP runtime to tolerate a second
# copy of itself being loaded (presumably because several imported libraries
# each bundle one). NOTE(review): this masks the conflict rather than fixing it.
os.environ.update(KMP_DUPLICATE_LIB_OK="TRUE")

class TotalVariationLoss(nn.Module):
    """Squared total-variation (smoothness) penalty for a 4-D tensor.

    With ``x`` indexed as ``x[n, c, h, w]``, the squared differences between
    neighbours along dim 2 and along dim 3 are each summed over the whole
    batch and divided by the number of such differences in ONE sample
    (``C*(H-1)*W`` and ``C*H*(W-1)``). The result is
    ``2 * (vertical_term + horizontal_term) / N``, i.e. a per-sample average.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the scalar TV penalty of ``x`` (expects a 4-D tensor)."""
        n_samples = x.size(0)
        # Differences between adjacent elements along dim 2 and dim 3.
        vertical = x[:, :, 1:, :] - x[:, :, :-1, :]
        horizontal = x[:, :, :, 1:] - x[:, :, :, :-1]
        vertical_term = vertical.pow(2).sum() / self.tensor_size(vertical)
        horizontal_term = horizontal.pow(2).sum() / self.tensor_size(horizontal)
        return 2 * (vertical_term + horizontal_term) / n_samples

    @staticmethod
    def tensor_size(t):
        """Elements in one sample of ``t``: the product of dims 1, 2 and 3."""
        channels, height, width = t.size(1), t.size(2), t.size(3)
        return channels * height * width


class CustomDataset(Dataset):
    """Image-folder dataset: one item per entry of the directory ``path``.

    Each item is whatever ``cv2.imread`` decodes for the file (called with
    default flags), passed through ``transform`` when one is given.
    """

    def __init__(self, path, transform=None):
        """
        Args:
            path: directory whose entries are the dataset's image files.
            transform: optional callable applied to each decoded image.
        """
        self.path = path
        self.transform = transform
        # Sorted so index -> file is stable across runs (os.listdir order is arbitrary).
        self.data = sorted(os.listdir(path))

    def __getitem__(self, index):
        """Load, decode and (optionally) transform the ``index``-th file.

        Raises:
            OSError: if OpenCV cannot decode the file (``cv2.imread`` returns None).
        """
        # Read from the directory given to __init__ (previously hard-coded to 'patches/').
        file_path = os.path.join(self.path, self.data[index])
        x = cv2.imread(file_path)
        if x is None:
            raise OSError(f"could not read image: {file_path}")
        if self.transform:
            x = self.transform(x)
        return x

    def __len__(self):
        """Number of entries found in ``path``."""
        return len(self.data)

# Preprocessing applied to each decoded image:
#   ToTensor      -> HWC uint8 array to CHW float tensor scaled to [0, 1]
#   GaussianBlur  -> 3x3 Gaussian smoothing
#   Normalize     -> (v - 0.5) / 0.5, i.e. [0, 1] -> [-1, 1]
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.GaussianBlur(3),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Every file in ./patches is a sample; shuffled mini-batches of 5 images.
# NOTE(review): the `batch_size = 32` hyperparameter defined later in this
# cell is not used here.
dataset = CustomDataset("patches", transform=transform)
dataloader = DataLoader(dataset, shuffle=True, batch_size=5)




class Generator(nn.Module):
    """Fully-connected generator: noise vector -> flattened image.

    z_dim -> 256 (ReLU) -> img_dim (Tanh), so every output value lies in
    (-1, 1). The output is flat; callers reshape it to an image.
    """

    def __init__(self, z_dim=10, img_dim=64*64):
        super().__init__()
        hidden = 256
        layers = [
            nn.Linear(z_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, img_dim),
            nn.Tanh(),
        ]
        # One nn.Sequential named `gen` keeps state_dict keys as gen.0.* / gen.2.*.
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        """Map noise whose last dimension is z_dim to img_dim outputs."""
        out = self.gen(x)
        return out

class Discriminator(nn.Module):
    """Fully-connected real/fake classifier on flattened images.

    img_dim -> 256 (ReLU) -> 1 (Sigmoid): one probability per input row.
    """

    def __init__(self, img_dim=64*64):
        super().__init__()
        hidden = 256
        self.disc = nn.Sequential(
            *[
                nn.Linear(img_dim, hidden),
                nn.ReLU(),
                nn.Linear(hidden, 1),
                nn.Sigmoid(),
            ]
        )

    def forward(self, x):
        """Score rows whose last dimension is img_dim; values lie in (0, 1)."""
        probabilities = self.disc(x)
        return probabilities

# Hyperparameters
lr = 0.001
z_dim = 64
img_dim = 64*64
# NOTE(review): not used by the DataLoader (built with batch_size=5) and
# overwritten with the per-step row count inside the loop below.
batch_size = 32
num_epochs = 1000
# D is updated only on epochs divisible by DISC_EVERY (G on every batch);
# presumably a deliberate handicap so D does not overpower G.
DISC_EVERY = 7
# Print losses / plot samples on the first batch of every PREVIEW_EVERY-th epoch.
PREVIEW_EVERY = 10

# Initialize generator and discriminator
gen = Generator(z_dim, img_dim)
disc = Discriminator(img_dim)

# Loss and optimizers
criterion = nn.BCELoss()
tv = TotalVariationLoss()
opt_gen = torch.optim.Adam(gen.parameters(), lr=lr)
opt_disc = torch.optim.Adam(disc.parameters(), lr=lr)

# Training loop
for epoch in tqdm(range(num_epochs)):
    for batch_idx, real in enumerate(dataloader):
        # One flattened 64*64 row per image plane. NOTE(review): cv2.imread loads
        # colour by default, so each image presumably contributes 3 rows here.
        real = real.view(-1, img_dim)
        batch_size = real.shape[0]
        noise = torch.randn(batch_size, z_dim)
        fake = gen(noise)

        if epoch % DISC_EVERY == 0:
            # Train Discriminator. `fake` is detached so D's loss does not
            # back-propagate into G and no graph has to be retained.
            disc_real = disc(real).view(-1)
            lossD_real = criterion(disc_real, torch.ones_like(disc_real))
            disc_fake = disc(fake.detach()).view(-1)
            lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
            lossD = (lossD_real + lossD_fake) / 2
            disc.zero_grad()
            lossD.backward()
            opt_disc.step()

        # Train Generator: push D to label fakes as real, plus a TV smoothness penalty.
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output)) + tv(fake.reshape(-1, 1, 64, 64))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx == 0 and epoch % PREVIEW_EVERY == 0:
            # lossD comes from the most recent D update, possibly an earlier epoch.
            print("Epoch: ", epoch, "| GenLoss: ", lossG.item(), "| DiscLoss: ", lossD.item())
            with torch.no_grad():
                fake = gen(torch.randn(batch_size, z_dim)).reshape(-1, 1, 64, 64)
                real_np = real.reshape(-1, 1, 64, 64)[0][0].cpu().numpy()
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                print(img_grid_fake.shape)

                # NOTE: the 'Greys' colormap draws high values dark (inverted).
                img_np = fake[0][0].cpu().numpy()
                plt.imshow(img_np, cmap='Greys')
                plt.show()
                plt.imshow(real_np, cmap='Greys')
                plt.show()
                



In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision
import os
import matplotlib.pyplot as plt
import numpy as np
import cv2
from tqdm import tqdm
# KMP_DUPLICATE_LIB_OK=TRUE tells the Intel OpenMP runtime to tolerate a second
# copy of itself being loaded (presumably because several imported libraries
# each bundle one). NOTE(review): this masks the conflict rather than fixing it.
os.environ.update(KMP_DUPLICATE_LIB_OK="TRUE")

class TotalVariationLoss(nn.Module):
    """Squared total-variation (smoothness) penalty for a 4-D tensor.

    With ``x`` indexed as ``x[n, c, h, w]``, the squared differences between
    neighbours along dim 2 and along dim 3 are each summed over the whole
    batch and divided by the number of such differences in ONE sample
    (``C*(H-1)*W`` and ``C*H*(W-1)``). The result is
    ``2 * (vertical_term + horizontal_term) / N``, i.e. a per-sample average.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the scalar TV penalty of ``x`` (expects a 4-D tensor)."""
        n_samples = x.size(0)
        # Differences between adjacent elements along dim 2 and dim 3.
        vertical = x[:, :, 1:, :] - x[:, :, :-1, :]
        horizontal = x[:, :, :, 1:] - x[:, :, :, :-1]
        vertical_term = vertical.pow(2).sum() / self.tensor_size(vertical)
        horizontal_term = horizontal.pow(2).sum() / self.tensor_size(horizontal)
        return 2 * (vertical_term + horizontal_term) / n_samples

    @staticmethod
    def tensor_size(t):
        """Elements in one sample of ``t``: the product of dims 1, 2 and 3."""
        channels, height, width = t.size(1), t.size(2), t.size(3)
        return channels * height * width


class CustomDataset(Dataset):
    """Image-folder dataset: one item per entry of the directory ``path``.

    Each item is whatever ``cv2.imread`` decodes for the file (called with
    default flags), passed through ``transform`` when one is given.
    """

    def __init__(self, path, transform=None):
        """
        Args:
            path: directory whose entries are the dataset's image files.
            transform: optional callable applied to each decoded image.
        """
        self.path = path
        self.transform = transform
        # Sorted so index -> file is stable across runs (os.listdir order is arbitrary).
        self.data = sorted(os.listdir(path))

    def __getitem__(self, index):
        """Load, decode and (optionally) transform the ``index``-th file.

        Raises:
            OSError: if OpenCV cannot decode the file (``cv2.imread`` returns None).
        """
        # Read from the directory given to __init__ (previously hard-coded to 'patches/').
        file_path = os.path.join(self.path, self.data[index])
        x = cv2.imread(file_path)
        if x is None:
            raise OSError(f"could not read image: {file_path}")
        if self.transform:
            x = self.transform(x)
        return x

    def __len__(self):
        """Number of entries found in ``path``."""
        return len(self.data)

# Preprocessing applied to each decoded image:
#   ToTensor      -> HWC uint8 array to CHW float tensor scaled to [0, 1]
#   GaussianBlur  -> 3x3 Gaussian smoothing
#   Normalize     -> (v - 0.5) / 0.5, i.e. [0, 1] -> [-1, 1]
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.GaussianBlur(3),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Every file in ./patches is a sample; shuffled mini-batches of 5 images.
# NOTE(review): the `batch_size = 32` hyperparameter defined later in this
# cell is not used here.
dataset = CustomDataset("patches", transform=transform)
dataloader = DataLoader(dataset, shuffle=True, batch_size=5)




class Generator(nn.Module):
    """Fully-connected generator: noise vector -> flattened image.

    z_dim -> 256 (ReLU) -> img_dim (Tanh), so every output value lies in
    (-1, 1). The output is flat; callers reshape it to an image.
    """

    def __init__(self, z_dim=10, img_dim=64*64):
        super().__init__()
        hidden = 256
        layers = [
            nn.Linear(z_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, img_dim),
            nn.Tanh(),
        ]
        # One nn.Sequential named `gen` keeps state_dict keys as gen.0.* / gen.2.*.
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        """Map noise whose last dimension is z_dim to img_dim outputs."""
        out = self.gen(x)
        return out

class Discriminator(nn.Module):
    """Fully-connected real/fake classifier on flattened images.

    img_dim -> 256 (ReLU) -> 1 (Sigmoid): one probability per input row.
    """

    def __init__(self, img_dim=64*64):
        super().__init__()
        hidden = 256
        self.disc = nn.Sequential(
            *[
                nn.Linear(img_dim, hidden),
                nn.ReLU(),
                nn.Linear(hidden, 1),
                nn.Sigmoid(),
            ]
        )

    def forward(self, x):
        """Score rows whose last dimension is img_dim; values lie in (0, 1)."""
        probabilities = self.disc(x)
        return probabilities

# Hyperparameters
lr = 0.001
z_dim = 64
img_dim = 64*64
# NOTE(review): not used by the DataLoader (built with batch_size=5) and
# overwritten with the per-step row count inside the loop below.
batch_size = 32
num_epochs = 300
# Weight of the total-variation smoothness penalty in the generator loss.
TV_WEIGHT = 0.1
# D is updated only on epochs divisible by DISC_EVERY (G on every batch);
# presumably a deliberate handicap so D does not overpower G.
DISC_EVERY = 7
# Print losses / plot samples on the first batch of every PREVIEW_EVERY-th epoch.
PREVIEW_EVERY = 10

# Initialize generator and discriminator
gen = Generator(z_dim, img_dim)
disc = Discriminator(img_dim)

# Loss and optimizers
criterion = nn.BCELoss()
tv = TotalVariationLoss()
opt_gen = torch.optim.Adam(gen.parameters(), lr=lr)
opt_disc = torch.optim.Adam(disc.parameters(), lr=lr)

# Training loop
for epoch in tqdm(range(num_epochs)):
    for batch_idx, real in enumerate(dataloader):
        # One flattened 64*64 row per image plane. NOTE(review): cv2.imread loads
        # colour by default, so each image presumably contributes 3 rows here.
        real = real.view(-1, img_dim)
        batch_size = real.shape[0]
        noise = torch.randn(batch_size, z_dim)
        fake = gen(noise)

        if epoch % DISC_EVERY == 0:
            # Train Discriminator. `fake` is detached so D's loss does not
            # back-propagate into G and no graph has to be retained.
            disc_real = disc(real).view(-1)
            lossD_real = criterion(disc_real, torch.ones_like(disc_real))
            disc_fake = disc(fake.detach()).view(-1)
            lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
            lossD = (lossD_real + lossD_fake) / 2
            disc.zero_grad()
            lossD.backward()
            opt_disc.step()

        # Train Generator: push D to label fakes as real, plus a weighted TV penalty.
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output)) + TV_WEIGHT * tv(fake.reshape(-1, 1, 64, 64))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx == 0 and epoch % PREVIEW_EVERY == 0:
            # lossD comes from the most recent D update, possibly an earlier epoch.
            print("Epoch: ", epoch, "| GenLoss: ", lossG.item(), "| DiscLoss: ", lossD.item())
            with torch.no_grad():
                fake = gen(torch.randn(batch_size, z_dim)).reshape(-1, 1, 64, 64)
                real_np = real.reshape(-1, 1, 64, 64)[0][0].cpu().numpy()
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                print(img_grid_fake.shape)

                # NOTE: the 'Greys' colormap draws high values dark (inverted).
                img_np = fake[0][0].cpu().numpy()
                plt.imshow(img_np, cmap='Greys')
                plt.show()
                plt.imshow(real_np, cmap='Greys')
                plt.show()
                



In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision
import os
import matplotlib.pyplot as plt
import numpy as np
import cv2
from tqdm import tqdm
import torch.nn.functional as F
# KMP_DUPLICATE_LIB_OK=TRUE tells the Intel OpenMP runtime to tolerate a second
# copy of itself being loaded (presumably because several imported libraries
# each bundle one). NOTE(review): this masks the conflict rather than fixing it.
os.environ.update(KMP_DUPLICATE_LIB_OK="TRUE")

class TotalVariationLoss(nn.Module):
    """Squared total-variation (smoothness) penalty for a 4-D tensor.

    With ``x`` indexed as ``x[n, c, h, w]``, the squared differences between
    neighbours along dim 2 and along dim 3 are each summed over the whole
    batch and divided by the number of such differences in ONE sample
    (``C*(H-1)*W`` and ``C*H*(W-1)``). The result is
    ``2 * (vertical_term + horizontal_term) / N``, i.e. a per-sample average.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the scalar TV penalty of ``x`` (expects a 4-D tensor)."""
        n_samples = x.size(0)
        # Differences between adjacent elements along dim 2 and dim 3.
        vertical = x[:, :, 1:, :] - x[:, :, :-1, :]
        horizontal = x[:, :, :, 1:] - x[:, :, :, :-1]
        vertical_term = vertical.pow(2).sum() / self.tensor_size(vertical)
        horizontal_term = horizontal.pow(2).sum() / self.tensor_size(horizontal)
        return 2 * (vertical_term + horizontal_term) / n_samples

    @staticmethod
    def tensor_size(t):
        """Elements in one sample of ``t``: the product of dims 1, 2 and 3."""
        channels, height, width = t.size(1), t.size(2), t.size(3)
        return channels * height * width


class CustomDataset(Dataset):
    """Image-folder dataset: one item per entry of the directory ``path``.

    Each item is whatever ``cv2.imread`` decodes for the file (called with
    default flags), passed through ``transform`` when one is given.
    """

    def __init__(self, path, transform=None):
        """
        Args:
            path: directory whose entries are the dataset's image files.
            transform: optional callable applied to each decoded image.
        """
        self.path = path
        self.transform = transform
        # Sorted so index -> file is stable across runs (os.listdir order is arbitrary).
        self.data = sorted(os.listdir(path))

    def __getitem__(self, index):
        """Load, decode and (optionally) transform the ``index``-th file.

        Raises:
            OSError: if OpenCV cannot decode the file (``cv2.imread`` returns None).
        """
        # Read from the directory given to __init__ (previously hard-coded to 'patches/').
        file_path = os.path.join(self.path, self.data[index])
        x = cv2.imread(file_path)
        if x is None:
            raise OSError(f"could not read image: {file_path}")
        if self.transform:
            x = self.transform(x)
        return x

    def __len__(self):
        """Number of entries found in ``path``."""
        return len(self.data)

# Preprocessing applied to each decoded image:
#   ToTensor      -> HWC uint8 array to CHW float tensor scaled to [0, 1]
#   GaussianBlur  -> 3x3 Gaussian smoothing
#   Normalize     -> (v - 0.5) / 0.5, i.e. [0, 1] -> [-1, 1]
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.GaussianBlur(3),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Every file in ./patches is a sample; shuffled mini-batches of 5 images.
# NOTE(review): the `batch_size = 32` hyperparameter defined later in this
# cell is not used here.
dataset = CustomDataset("patches", transform=transform)
dataloader = DataLoader(dataset, shuffle=True, batch_size=5)




class Generator(nn.Module):
    """DCGAN-style generator: (B, z_dim, 1, 1) noise -> (B, 1, 64, 64) image.

    Four transposed convolutions (each followed by BatchNorm and LeakyReLU)
    grow the spatial size 1 -> 4 -> 8 -> 16 -> 32. A final transposed
    convolution to one channel, followed by Tanh, yields 64 x 64 values in (-1, 1).

    Args:
        z_dim: number of channels of the (B, z_dim, 1, 1) noise input.
        img_dim: unused. It is kept for call-site compatibility
            (``Generator(z_dim, img_dim)``); the output size is fixed by the layers.
    """

    def __init__(self, z_dim, img_dim):
        super(Generator, self).__init__()
        self.tconv1 = nn.ConvTranspose2d(z_dim, 1024, kernel_size=4, stride=1)
        self.bn1 = nn.BatchNorm2d(1024)

        self.tconv2 = nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(512)

        self.tconv3 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(256)

        self.tconv3_a = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
        self.bn3_a = nn.BatchNorm2d(128)

        self.tconv4 = nn.ConvTranspose2d(128, 1, kernel_size=4, stride=2, padding=1)
        self._init_weights()

    def forward(self, x):
        """Map (B, z_dim, 1, 1) noise to a (B, 1, 64, 64) image in (-1, 1)."""
        x = F.leaky_relu(self.bn1(self.tconv1(x)), negative_slope=0.02)  # (B, 1024, 4, 4)
        x = F.leaky_relu(self.bn2(self.tconv2(x)), negative_slope=0.02)  # (B, 512, 8, 8)
        x = F.leaky_relu(self.bn3(self.tconv3(x)), negative_slope=0.02)  # (B, 256, 16, 16)
        x = F.leaky_relu(self.bn3_a(self.tconv3_a(x)), negative_slope=0.02)  # (B, 128, 32, 32)
        # torch.tanh replaces the deprecated F.tanh (same result).
        return torch.tanh(self.tconv4(x))  # (B, 1, 64, 64)

    def _init_weights(self):
        """DCGAN-style init: conv weights ~ N(0, 0.02); BN scale ~ N(1, 0.02), BN shift = 0."""
        for m in self.modules():
            if isinstance(m, nn.ConvTranspose2d):
                nn.init.normal_(m.weight.data, 0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight.data, 1.0, 0.02)
                nn.init.constant_(m.bias.data, 0)




class Discriminator(nn.Module):
    """Fully-connected real/fake classifier on flattened images.

    img_dim -> 256 (ReLU) -> 1 (Sigmoid): one probability per input row.
    """

    def __init__(self, img_dim=64*64):
        super().__init__()
        hidden = 256
        self.disc = nn.Sequential(
            *[
                nn.Linear(img_dim, hidden),
                nn.ReLU(),
                nn.Linear(hidden, 1),
                nn.Sigmoid(),
            ]
        )

    def forward(self, x):
        """Score rows whose last dimension is img_dim; values lie in (0, 1)."""
        probabilities = self.disc(x)
        return probabilities

# Hyperparameters
lr = 0.001
z_dim = 64
img_dim = 64*64
# NOTE(review): not used by the DataLoader (built with batch_size=5) and
# overwritten with the per-step row count inside the loop below.
batch_size = 32
num_epochs = 300
# D is updated only on epochs divisible by DISC_EVERY (G on every batch);
# presumably a deliberate handicap so D does not overpower G.
DISC_EVERY = 7
# Print losses / plot samples on the first batch of every PREVIEW_EVERY-th epoch.
PREVIEW_EVERY = 10

# Initialize generator and discriminator
gen = Generator(z_dim, img_dim)
disc = Discriminator(img_dim)

# Loss and optimizers
criterion = nn.BCELoss()
tv = TotalVariationLoss()
opt_gen = torch.optim.Adam(gen.parameters(), lr=lr)
opt_disc = torch.optim.Adam(disc.parameters(), lr=lr)

# Training loop
for epoch in tqdm(range(num_epochs)):
    for batch_idx, real in enumerate(dataloader):
        # One flattened 64*64 row per image plane. NOTE(review): cv2.imread loads
        # colour by default, so each image presumably contributes 3 rows here.
        real = real.view(-1, img_dim)
        batch_size = real.shape[0]
        # The convolutional generator takes noise shaped (B, z_dim, 1, 1).
        noise = torch.randn(batch_size, z_dim, 1, 1)
        fake = gen(noise)

        if epoch % DISC_EVERY == 0:
            # Train Discriminator. `fake` is detached so D's loss does not
            # back-propagate into G and no graph has to be retained.
            disc_real = disc(real).view(-1)
            lossD_real = criterion(disc_real, torch.ones_like(disc_real))
            disc_fake = disc(fake.detach().reshape(-1, img_dim)).view(-1)
            lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
            lossD = (lossD_real + lossD_fake) / 2
            disc.zero_grad()
            lossD.backward()
            opt_disc.step()

        # Train Generator: push D to label fakes as real, plus a TV smoothness penalty.
        output = disc(fake.reshape(-1, img_dim)).view(-1)
        lossG = criterion(output, torch.ones_like(output)) + tv(fake.reshape(-1, 1, 64, 64))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx == 0 and epoch % PREVIEW_EVERY == 0:
            # lossD comes from the most recent D update, possibly an earlier epoch.
            print("Epoch: ", epoch, "| GenLoss: ", lossG.item(), "| DiscLoss: ", lossD.item())
            # NOTE(review): gen stays in train mode here, so this sampling pass
            # also updates its BatchNorm running statistics.
            with torch.no_grad():
                fake = gen(torch.randn(batch_size, z_dim, 1, 1)).reshape(-1, 1, 64, 64)
                real_np = real.reshape(-1, 1, 64, 64)[0][0].cpu().numpy()
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)

                # NOTE: the 'Greys' colormap draws high values dark (inverted).
                img_np = fake[0][0].cpu().numpy()
                plt.imshow(img_np, cmap='Greys')
                plt.show()
                plt.imshow(real_np, cmap='Greys')
                plt.show()
                



In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision
import os
import matplotlib.pyplot as plt
import numpy as np
import cv2
from tqdm import tqdm
import torch.nn.functional as F
# KMP_DUPLICATE_LIB_OK=TRUE tells the Intel OpenMP runtime to tolerate a second
# copy of itself being loaded (presumably because several imported libraries
# each bundle one). NOTE(review): this masks the conflict rather than fixing it.
os.environ.update(KMP_DUPLICATE_LIB_OK="TRUE")

class TotalVariationLoss(nn.Module):
    """Squared total-variation (smoothness) penalty for a 4-D tensor.

    With ``x`` indexed as ``x[n, c, h, w]``, the squared differences between
    neighbours along dim 2 and along dim 3 are each summed over the whole
    batch and divided by the number of such differences in ONE sample
    (``C*(H-1)*W`` and ``C*H*(W-1)``). The result is
    ``2 * (vertical_term + horizontal_term) / N``, i.e. a per-sample average.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the scalar TV penalty of ``x`` (expects a 4-D tensor)."""
        n_samples = x.size(0)
        # Differences between adjacent elements along dim 2 and dim 3.
        vertical = x[:, :, 1:, :] - x[:, :, :-1, :]
        horizontal = x[:, :, :, 1:] - x[:, :, :, :-1]
        vertical_term = vertical.pow(2).sum() / self.tensor_size(vertical)
        horizontal_term = horizontal.pow(2).sum() / self.tensor_size(horizontal)
        return 2 * (vertical_term + horizontal_term) / n_samples

    @staticmethod
    def tensor_size(t):
        """Elements in one sample of ``t``: the product of dims 1, 2 and 3."""
        channels, height, width = t.size(1), t.size(2), t.size(3)
        return channels * height * width


class CustomDataset(Dataset):
    """Image-folder dataset: one item per entry of the directory ``path``.

    Each item is whatever ``cv2.imread`` decodes for the file (called with
    default flags), passed through ``transform`` when one is given.
    """

    def __init__(self, path, transform=None):
        """
        Args:
            path: directory whose entries are the dataset's image files.
            transform: optional callable applied to each decoded image.
        """
        self.path = path
        self.transform = transform
        # Sorted so index -> file is stable across runs (os.listdir order is arbitrary).
        self.data = sorted(os.listdir(path))

    def __getitem__(self, index):
        """Load, decode and (optionally) transform the ``index``-th file.

        Raises:
            OSError: if OpenCV cannot decode the file (``cv2.imread`` returns None).
        """
        # Read from the directory given to __init__ (previously hard-coded to 'patches/').
        file_path = os.path.join(self.path, self.data[index])
        x = cv2.imread(file_path)
        if x is None:
            raise OSError(f"could not read image: {file_path}")
        if self.transform:
            x = self.transform(x)
        return x

    def __len__(self):
        """Number of entries found in ``path``."""
        return len(self.data)

# Preprocessing applied to each decoded image:
#   ToTensor      -> HWC uint8 array to CHW float tensor scaled to [0, 1]
#   GaussianBlur  -> 3x3 Gaussian smoothing
#   Normalize     -> (v - 0.5) / 0.5, i.e. [0, 1] -> [-1, 1]
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.GaussianBlur(3),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Every file in ./patches is a sample; shuffled mini-batches of 5 images.
# NOTE(review): the `batch_size = 32` hyperparameter defined later in this
# cell is not used here.
dataset = CustomDataset("patches", transform=transform)
dataloader = DataLoader(dataset, shuffle=True, batch_size=5)




import torch
from torch import nn

class Generator(nn.Module):
    """Upsample + conv generator: (N, z_dim, 1, 1) noise -> (N, channels_img, 64, 64).

    Each stage upsamples bilinearly, then applies a size-preserving 3x3 conv
    (stride 1, padding 1). For a 1x1 input the spatial size goes
    1 -> 4 -> 8 -> 16 -> 32 -> 64; a final Tanh maps values into (-1, 1).

    Args:
        z_dim: channels of the noise input.
        channels_img: channels of the generated image.
        features_g: width multiplier for the hidden conv layers.
    """

    def __init__(self, z_dim=10, channels_img=1, features_g=64):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            # Input: N x z_dim x 1 x 1
            self._block1(z_dim, features_g * 16, 3, 1, 1),  # img: 4x4
            self._block(features_g * 16, features_g * 8, 3, 1, 1),  # img: 8x8
            self._block(features_g * 8, features_g * 4, 3, 1, 1),  # img: 16x16
            self._block(features_g * 4, features_g * 2, 3, 1, 1),  # img: 32x32
            # Final x2 upsample + conv to the image channels.
            # Output: N x channels_img x 64 x 64
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(features_g * 2, channels_img, 3, 1, 1, bias=False),
            nn.Tanh(),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding, scale_factor=2):
        """Bilinear upsample by ``scale_factor``, then Conv2d -> BatchNorm2d -> ReLU.

        With kernel_size=3, stride=1, padding=1 the conv keeps the spatial size,
        so the block multiplies H and W by ``scale_factor``.
        """
        return nn.Sequential(
            nn.Upsample(scale_factor=scale_factor, mode='bilinear'),
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def _block1(self, in_channels, out_channels, kernel_size, stride, padding):
        """First stage: same as ``_block`` but upsamples x4 (1x1 noise -> 4x4)."""
        return self._block(in_channels, out_channels, kernel_size, stride, padding, scale_factor=4)

    def forward(self, x):
        """Map (N, z_dim, 1, 1) noise to an (N, channels_img, 64, 64) image in (-1, 1)."""
        return self.gen(x)







class Discriminator(nn.Module):
    """Fully-connected real/fake classifier on flattened images.

    img_dim -> 256 (ReLU) -> 1 (Sigmoid): one probability per input row.
    """

    def __init__(self, img_dim=64*64):
        super().__init__()
        hidden = 256
        self.disc = nn.Sequential(
            *[
                nn.Linear(img_dim, hidden),
                nn.ReLU(),
                nn.Linear(hidden, 1),
                nn.Sigmoid(),
            ]
        )

    def forward(self, x):
        """Score rows whose last dimension is img_dim; values lie in (0, 1)."""
        probabilities = self.disc(x)
        return probabilities

# Hyperparameters
lr = 0.001
z_dim = 64
img_dim = 64*64
# NOTE(review): not used by the DataLoader (built with batch_size=5) and
# overwritten with the per-step row count inside the loop below.
batch_size = 32
num_epochs = 300
# D is updated only on epochs divisible by DISC_EVERY (G on every batch);
# presumably a deliberate handicap so D does not overpower G.
DISC_EVERY = 7
# Print losses / plot samples on the first batch of every PREVIEW_EVERY-th epoch.
PREVIEW_EVERY = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Initialize generator and discriminator
gen = Generator(z_dim).to(device)
disc = Discriminator(img_dim).to(device)

# Loss and optimizers
criterion = nn.BCELoss()
tv = TotalVariationLoss()
opt_gen = torch.optim.Adam(gen.parameters(), lr=lr)
opt_disc = torch.optim.Adam(disc.parameters(), lr=lr)

# Training loop
for epoch in tqdm(range(num_epochs)):
    for batch_idx, real in enumerate(dataloader):
        # One flattened 64*64 row per image plane. NOTE(review): cv2.imread loads
        # colour by default, so each image presumably contributes 3 rows here.
        real = real.view(-1, img_dim).to(device)
        batch_size = real.shape[0]
        noise = torch.randn(batch_size, z_dim, 1, 1).to(device)

        fake = gen(noise)
        if epoch % DISC_EVERY == 0:
            # Train Discriminator. `fake` is detached so D's loss does not
            # back-propagate into G and no graph has to be retained.
            disc_real = disc(real).view(-1)
            lossD_real = criterion(disc_real, torch.ones_like(disc_real))
            disc_fake = disc(fake.detach().reshape(-1, img_dim)).view(-1)
            lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
            lossD = (lossD_real + lossD_fake) / 2
            disc.zero_grad()
            lossD.backward()
            opt_disc.step()

        # Train Generator. The TV penalty is intentionally disabled in this
        # variant (adversarial loss only); `tv` is constructed but unused.
        output = disc(fake.reshape(-1, img_dim)).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx == 0 and epoch % PREVIEW_EVERY == 0:
            # lossD comes from the most recent D update, possibly an earlier epoch.
            print("Epoch: ", epoch, "| GenLoss: ", lossG.item(), "| DiscLoss: ", lossD.item())
            # NOTE(review): gen stays in train mode here, so this sampling pass
            # also updates its BatchNorm running statistics.
            with torch.no_grad():
                fake = gen(torch.randn(batch_size, z_dim, 1, 1).to(device)).reshape(-1, 1, 64, 64)
                real_np = real.reshape(-1, 1, 64, 64)[0][0].cpu().numpy()
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)

                # NOTE: the 'Greys' colormap draws high values dark (inverted).
                img_np = fake[0][0].cpu().numpy()
                plt.imshow(img_np, cmap='Greys')
                plt.show()
                plt.imshow(real_np, cmap='Greys')
                plt.show()
                



In [None]:
# Persist only the generator's learned weights (its state_dict), not the module object.
torch.save(obj=gen.state_dict(), f="NoduleGenerator.pt")

In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision
import os
import matplotlib.pyplot as plt
import numpy as np
import cv2
from tqdm import tqdm
import torch.nn.functional as F
# KMP_DUPLICATE_LIB_OK=TRUE tells the Intel OpenMP runtime to tolerate a second
# copy of itself being loaded (presumably because several imported libraries
# each bundle one). NOTE(review): this masks the conflict rather than fixing it.
os.environ.update(KMP_DUPLICATE_LIB_OK="TRUE")

class TotalVariationLoss(nn.Module):
    """Squared total-variation (smoothness) penalty for a 4-D tensor.

    With ``x`` indexed as ``x[n, c, h, w]``, the squared differences between
    neighbours along dim 2 and along dim 3 are each summed over the whole
    batch and divided by the number of such differences in ONE sample
    (``C*(H-1)*W`` and ``C*H*(W-1)``). The result is
    ``2 * (vertical_term + horizontal_term) / N``, i.e. a per-sample average.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the scalar TV penalty of ``x`` (expects a 4-D tensor)."""
        n_samples = x.size(0)
        # Differences between adjacent elements along dim 2 and dim 3.
        vertical = x[:, :, 1:, :] - x[:, :, :-1, :]
        horizontal = x[:, :, :, 1:] - x[:, :, :, :-1]
        vertical_term = vertical.pow(2).sum() / self.tensor_size(vertical)
        horizontal_term = horizontal.pow(2).sum() / self.tensor_size(horizontal)
        return 2 * (vertical_term + horizontal_term) / n_samples

    @staticmethod
    def tensor_size(t):
        """Elements in one sample of ``t``: the product of dims 1, 2 and 3."""
        channels, height, width = t.size(1), t.size(2), t.size(3)
        return channels * height * width


class CustomDataset(Dataset):
    """Image-folder dataset: one item per entry of the directory ``path``.

    Each item is whatever ``cv2.imread`` decodes for the file (called with
    default flags), passed through ``transform`` when one is given.
    """

    def __init__(self, path, transform=None):
        """
        Args:
            path: directory whose entries are the dataset's image files.
            transform: optional callable applied to each decoded image.
        """
        self.path = path
        self.transform = transform
        # Sorted so index -> file is stable across runs (os.listdir order is arbitrary).
        self.data = sorted(os.listdir(path))

    def __getitem__(self, index):
        """Load, decode and (optionally) transform the ``index``-th file.

        Raises:
            OSError: if OpenCV cannot decode the file (``cv2.imread`` returns None).
        """
        # Read from the directory given to __init__ (previously hard-coded to 'patches/').
        file_path = os.path.join(self.path, self.data[index])
        x = cv2.imread(file_path)
        if x is None:
            raise OSError(f"could not read image: {file_path}")
        if self.transform:
            x = self.transform(x)
        return x

    def __len__(self):
        """Number of entries found in ``path``."""
        return len(self.data)

# Preprocessing applied to each decoded image:
#   ToTensor      -> HWC uint8 array to CHW float tensor scaled to [0, 1]
#   GaussianBlur  -> 3x3 Gaussian smoothing
#   Normalize     -> (v - 0.5) / 0.5, i.e. [0, 1] -> [-1, 1]
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.GaussianBlur(3),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Every file in ./patches is a sample; shuffled mini-batches of 5 images.
# NOTE(review): the `batch_size = 32` hyperparameter defined later in this
# cell is not used here.
dataset = CustomDataset("patches", transform=transform)
dataloader = DataLoader(dataset, shuffle=True, batch_size=5)




import torch
from torch import nn

class GeneratorA(nn.Module):
    """Upsample + conv generator: (N, z_dim, 1, 1) noise -> (N, channels_img, 64, 64).

    Each stage upsamples bilinearly, then applies a size-preserving 3x3 conv
    (stride 1, padding 1). For a 1x1 input the spatial size goes
    1 -> 4 -> 8 -> 16 -> 32 -> 64; a final Tanh maps values into (-1, 1).

    Args:
        z_dim: channels of the noise input.
        channels_img: channels of the generated image.
        features_g: width multiplier for the hidden conv layers.
    """

    def __init__(self, z_dim=10, channels_img=1, features_g=64):
        super(GeneratorA, self).__init__()
        self.gen = nn.Sequential(
            # Input: N x z_dim x 1 x 1
            self._block1(z_dim, features_g * 16, 3, 1, 1),  # img: 4x4
            self._block(features_g * 16, features_g * 8, 3, 1, 1),  # img: 8x8
            self._block(features_g * 8, features_g * 4, 3, 1, 1),  # img: 16x16
            self._block(features_g * 4, features_g * 2, 3, 1, 1),  # img: 32x32
            # Final x2 upsample + conv to the image channels.
            # Output: N x channels_img x 64 x 64
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(features_g * 2, channels_img, 3, 1, 1, bias=False),
            nn.Tanh(),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding, scale_factor=2):
        """Bilinear upsample by ``scale_factor``, then Conv2d -> BatchNorm2d -> ReLU.

        With kernel_size=3, stride=1, padding=1 the conv keeps the spatial size,
        so the block multiplies H and W by ``scale_factor``.
        """
        return nn.Sequential(
            nn.Upsample(scale_factor=scale_factor, mode='bilinear'),
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def _block1(self, in_channels, out_channels, kernel_size, stride, padding):
        """First stage: same as ``_block`` but upsamples x4 (1x1 noise -> 4x4)."""
        return self._block(in_channels, out_channels, kernel_size, stride, padding, scale_factor=4)

    def forward(self, x):
        """Map (N, z_dim, 1, 1) noise to an (N, channels_img, 64, 64) image in (-1, 1)."""
        return self.gen(x)



class GeneratorB(nn.Module):
    """Fully-connected generator: noise vector -> flattened image.

    z_dim -> 256 (ReLU) -> img_dim (Tanh), so every output value lies in
    (-1, 1). The output is flat; callers reshape it to an image.
    """

    def __init__(self, z_dim=10, img_dim=64*64):
        super().__init__()
        hidden = 256
        layers = [
            nn.Linear(z_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, img_dim),
            nn.Tanh(),
        ]
        # One nn.Sequential named `gen` keeps state_dict keys as gen.0.* / gen.2.*.
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        """Map noise whose last dimension is z_dim to img_dim outputs."""
        out = self.gen(x)
        return out



class Discriminator(nn.Module):
    """Fully-connected real/fake classifier on flattened images.

    img_dim -> 256 (ReLU) -> 1 (Sigmoid): one probability per input row.
    """

    def __init__(self, img_dim=64*64):
        super().__init__()
        hidden = 256
        self.disc = nn.Sequential(
            *[
                nn.Linear(img_dim, hidden),
                nn.ReLU(),
                nn.Linear(hidden, 1),
                nn.Sigmoid(),
            ]
        )

    def forward(self, x):
        """Score rows whose last dimension is img_dim; values lie in (0, 1)."""
        probabilities = self.disc(x)
        return probabilities

# Hyperparameters
lr = 0.001
z_dim = 64
img_dim = 64*64
# NOTE(review): not used by the DataLoader (built with batch_size=5) and
# overwritten with the per-step row count inside the loop below.
batch_size = 32
num_epochs = 300
# Weight of the fully-connected generator's contribution to each image.
alpha = 0.08
# D is updated only on epochs divisible by DISC_EVERY (G on every batch);
# presumably a deliberate handicap so D does not overpower G.
DISC_EVERY = 7
# Print losses / plot samples on the first batch of every PREVIEW_EVERY-th epoch.
PREVIEW_EVERY = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the two (blended) generators and the discriminator
genA = GeneratorA(z_dim).to(device)
genB = GeneratorB(z_dim).to(device)
disc = Discriminator(img_dim).to(device)

# Loss and optimizers. lossG back-propagates into BOTH generators, so a single
# optimizer owns the parameters of genA and genB together.
criterion = nn.BCELoss()
tv = TotalVariationLoss()
opt_gen = torch.optim.Adam(list(genA.parameters()) + list(genB.parameters()), lr=lr)
opt_disc = torch.optim.Adam(disc.parameters(), lr=lr)


def _blend(noise):
    """Return genA(noise) + alpha * genB(noise) as a (B, 1, 64, 64) batch.

    ``noise`` is (B, z_dim, 1, 1); genB receives it flattened to (B, z_dim).
    genB's flat output is reshaped to (B, 1, 64, 64), matching genA's
    (B, 1, 64, 64) output. The sum is therefore element-wise instead of
    broadcasting to (B, B, 64, 64).
    """
    conv_part = genA(noise)
    fc_part = genB(noise.reshape(-1, z_dim)).reshape(-1, 1, 64, 64)
    return conv_part + alpha * fc_part


# Training loop
for epoch in tqdm(range(num_epochs)):
    for batch_idx, real in enumerate(dataloader):
        # One flattened 64*64 row per image plane. NOTE(review): cv2.imread loads
        # colour by default, so each image presumably contributes 3 rows here.
        real = real.view(-1, img_dim).to(device)
        batch_size = real.shape[0]
        noise = torch.randn(batch_size, z_dim, 1, 1).to(device)

        fake = _blend(noise)
        if epoch % DISC_EVERY == 0:
            # Train Discriminator. `fake` is detached so D's loss does not
            # back-propagate into the generators and no graph has to be retained.
            disc_real = disc(real).view(-1)
            lossD_real = criterion(disc_real, torch.ones_like(disc_real))
            disc_fake = disc(fake.detach().reshape(-1, img_dim)).view(-1)
            lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
            lossD = (lossD_real + lossD_fake) / 2
            disc.zero_grad()
            lossD.backward()
            opt_disc.step()

        # Train both generators: push D to label fakes as real, plus a TV penalty.
        output = disc(fake.reshape(-1, img_dim)).view(-1)
        lossG = criterion(output, torch.ones_like(output)) + tv(fake)
        genA.zero_grad()
        genB.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx == 0 and epoch % PREVIEW_EVERY == 0:
            # lossD comes from the most recent D update, possibly an earlier epoch.
            print("Epoch: ", epoch, "| GenLoss: ", lossG.item(), "| DiscLoss: ", lossD.item())
            # NOTE(review): genA stays in train mode here, so this sampling pass
            # also updates its BatchNorm running statistics.
            with torch.no_grad():
                noise = torch.randn(batch_size, z_dim, 1, 1).to(device)
                fake = _blend(noise)
                real_np = real.reshape(-1, 1, 64, 64)[0][0].cpu().numpy()
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)

                # NOTE: the 'Greys' colormap draws high values dark (inverted).
                img_np = fake[0][0].cpu().numpy()
                plt.imshow(img_np, cmap='Greys')
                plt.show()
                plt.imshow(real_np, cmap='Greys')
                plt.show()
                



In [None]:
# Persist the learned weights (state_dicts) of both blended generators, one file each.
for _module, _filename in ((genA, "NoduleGeneratorA.pt"), (genB, "NoduleGeneratorB.pt")):
    torch.save(_module.state_dict(), _filename)

