In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision
import os
import matplotlib.pyplot as plt
import numpy as np
import cv2
from tqdm import tqdm
import torch.nn.functional as F
import pandas as pd
from sklearn.mixture import GaussianMixture
import random

def interpolate_nodule(nodule_tensor, S):
    """Resize a nodule patch to ``S x S`` pixels with bilinear interpolation.

    ``S`` is read with ``.item()``, truncated to an int and made non-negative; a
    size of 0 is replaced by 2 so the result is never empty.

    Returns the resized patch with two leading singleton dimensions added, i.e.
    ``(1, 1, S, S)`` for a 2-D input patch.
    """
    side = abs(int(S.item()))
    if not side:
        side = 2
    # Add batch and channel axes: F.interpolate's bilinear mode needs a 4-D input.
    batched = nodule_tensor[None, None, ...]
    return F.interpolate(batched, size=(side, side), mode='bilinear', align_corners=False)

def add_nodule_to_image(larger_image, nodule, x, y, scale=8):
    """Add ``nodule`` onto ``larger_image`` centred near ``(x / scale, y / scale)``.

    ``x`` and ``y`` are divided by ``scale`` and truncated to ints. The patch values
    are *added* to the image region they cover.

    Placement near a border:
    - First, the legacy heuristic nudges the centre. It moves 20 px back when the
      patch crosses the right/bottom edge, and sets the coordinate to 20 when it
      crosses the left/top edge.
    - Then the centre is clamped so the whole patch lies inside the image.
    - A patch larger than the image is centre-cropped to the image size first.

    Args:
        larger_image: Image tensor. A singleton axis is inserted at position 1,
            e.g. ``(1, H, W)`` becomes ``(1, 1, H, W)``. Presumably a
            ``(C, H, W)`` image — confirm with callers.
        nodule: Patch whose last two dims are its height and width, e.g. the
            ``(1, 1, h, w)`` output of ``interpolate_nodule``.
        x, y: Centre coordinates before scaling (0-d tensors or numbers).
        scale: Divisor mapping ``x``/``y`` into ``larger_image`` pixel coordinates.

    Returns:
        ``larger_image.unsqueeze(1)`` with the patch added. This is a view, so the
        caller's ``larger_image`` is modified in place too.
    """
    x = int(x / scale)
    y = int(y / scale)
    canvas = larger_image.unsqueeze(1)
    nodule_h, nodule_w = nodule.shape[-2:]
    canvas_h, canvas_w = canvas.shape[-2:]

    # A patch bigger than the canvas can never fit: keep its central part.
    if nodule_h > canvas_h or nodule_w > canvas_w:
        top = max((nodule_h - canvas_h) // 2, 0)
        left = max((nodule_w - canvas_w) // 2, 0)
        nodule = nodule[..., top:top + min(nodule_h, canvas_h), left:left + min(nodule_w, canvas_w)]
        nodule_h, nodule_w = nodule.shape[-2:]

    # Legacy heuristic: nudge the centre away from the border.
    if x + nodule_w / 2 > canvas_w:
        x = x - 20
    if y + nodule_h / 2 > canvas_h:
        y = y - 20
    if x - nodule_w / 2 < 0:
        x = 20
    if y - nodule_h / 2 < 0:
        y = 20

    # The heuristic is not enough for large patches (e.g. width > 40 at the left
    # edge gives a negative slice start). Clamp so start >= 0 and end <= size.
    # When the heuristic already fits the patch, this clamp is a no-op.
    x = min(max(x, nodule_w // 2), canvas_w - nodule_w + nodule_w // 2)
    y = min(max(y, nodule_h // 2), canvas_h - nodule_h + nodule_h // 2)

    # start = centre - size//2 and end = start + size covers both odd and even sizes.
    top = y - nodule_h // 2
    left = x - nodule_w // 2
    canvas[:, :, top:top + nodule_h, left:left + nodule_w] += nodule
    return canvas

class NoduleGenerator(nn.Module):
    """Upsample-and-convolve generator producing ``channels_img``-channel images.

    For an input of shape ``(N, z_dim, 1, 1)`` the spatial size grows
    1 -> 4 -> 8 -> 16 -> 32 -> 64, giving ``(N, channels_img, 64, 64)``.
    The output goes through ``tanh``.

    The layer order inside ``self.gen`` defines the ``state_dict`` keys
    (``gen.0.1.weight`` ... ``gen.5.weight``). It must not change, or saved
    checkpoints will no longer load.
    """

    def __init__(self, z_dim=10, channels_img=1, features_g=64):
        super().__init__()
        widths = [features_g * 16, features_g * 8, features_g * 4, features_g * 2]
        stages = [self._block1(z_dim, widths[0], 3, 1, 1)]
        for in_ch, out_ch in zip(widths, widths[1:]):
            stages.append(self._block(in_ch, out_ch, 3, 1, 1))
        stages += [
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(widths[-1], channels_img, 3, 1, 1, bias=False),
            nn.Tanh(),
        ]
        self.gen = nn.Sequential(*stages)

    @staticmethod
    def _upsample_conv(in_channels, out_channels, kernel_size, stride, padding, factor):
        """Bilinear upsample by ``factor``, then conv (no bias) -> BatchNorm -> ReLU."""
        return nn.Sequential(
            nn.Upsample(scale_factor=factor, mode='bilinear'),
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Upsampling stage that doubles height and width."""
        return self._upsample_conv(in_channels, out_channels, kernel_size, stride, padding, 2)

    def _block1(self, in_channels, out_channels, kernel_size, stride, padding):
        """First upsampling stage, scaling height and width by 4."""
        return self._upsample_conv(in_channels, out_channels, kernel_size, stride, padding, 4)

    def forward(self, x):
        return self.gen(x)

# Refinement Network
class RefinementNetwork(nn.Module):
    """Three-layer convolutional network mapping a 1-channel image to a 1-channel image.

    Every convolution is 3x3 with stride 1 and padding 1, so height and width are
    preserved. The final sigmoid keeps outputs in (0, 1).

    NOTE(review): this file's ``transform`` applies ``Normalize((0.5,), (0.5,))``,
    which presumably puts inputs in [-1, 1]. Confirm that a (0, 1) output range is
    intended for the L1 target.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_image_with_nodule):
        features = input_image_with_nodule
        for conv in (self.conv1, self.conv2):
            features = self.relu(conv(features))
        return self.sigmoid(self.conv3(features))

# Discriminator (for GAN training)
class Discriminator(nn.Module):
    """Two 3x3 conv layers, a fully connected layer and a sigmoid score per image.

    Args:
        image_size: Height and width of the square, single-channel input images.
            Both convolutions use stride 1 and padding 1, so the flattened feature
            count is ``64 * image_size**2``; the FC layer is sized accordingly.
            Defaults to 256.
    """

    def __init__(self, image_size=256):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.fc = nn.Linear(image_size * image_size * 64, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_image):
        """Return a ``(N, 1)`` tensor of probabilities in (0, 1)."""
        x = self.relu(self.conv1(input_image))
        x = self.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # Flatten for FC layer
        x = self.sigmoid(self.fc(x))
        return x

# # Define the training process
# def train(refinement_network, discriminator,nodule_generator, dataloader, num_epochs, device):
#     refinement_network.to(device)
#     discriminator.to(device)
#     nodule_generator.to(device)

#     optimizer_refinement = optim.Adam(refinement_network.parameters(), lr=0.0002, betas=(0.5, 0.999))
#     optimizer_discriminator = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
#     criterion = nn.BCELoss()

#     for epoch in range(num_epochs):
#         for idx, data in enumerate(dataloader):
            
#             with torch.no_grad():
#                 clean,nodule,loc=data
#                 clean.to(device)
#                 nodule.to(device)
#                 loc.to(device)
#                 nodules = nodule_generator(torch.randn(data.shape[0], 64)).reshape(-1,64,64).to(device)
#                 for i in range(nodules.shape[0]):
#                     interpolated_nodule = interpolate_nodule(nodules[i], loc[i,2])
#                     clean[i]=add_nodule_to_image(clean[i],interpolated_nodule,loc[i,0],loc[i,1])
                
#             refinement_network.zero_grad()
#             fake_output = refinement_network(clean)
#             refinement_loss = torch.mean(torch.abs(fake_output - clean))
#             refinement_loss.backward()
#             optimizer_refinement.step()

#             if(idx%4==0):
#                 discriminator.zero_grad()
#                 real_output = discriminator(nodule)
#                 fake_output = discriminator(refinement_network(clean))
#                 real_loss = criterion(real_output, torch.ones_like(real_output))
#                 fake_loss = criterion(fake_output, torch.zeros_like(fake_output))
#                 discriminator_loss = (real_loss + fake_loss) / 2
#                 discriminator_loss.backward()
#                 optimizer_discriminator.step()
import torchvision.transforms.functional as TF

def gaussian_mask(size, sigma):
    """
    Generates a 2D Gaussian mask.

    The grid is centred on the middle of the mask, so the result is symmetric for
    both odd and even ``size``. For odd sizes the centre pixel has value 1.

    Parameters:
    - size (int): Size of the mask (both height and width).
    - sigma (float): Standard deviation of the Gaussian distribution; must be positive.

    Returns:
    - torch.Tensor: ``size x size`` unnormalised Gaussian mask with values in (0, 1].

    Raises:
    - ValueError: if ``sigma`` is not positive.
    """
    if sigma <= 0:
        raise ValueError(f"sigma must be positive, got {sigma}")
    # Offset of each pixel from the centre. For even sizes the centre lies between
    # the two middle pixels, so the offsets are half-integers.
    ax = torch.arange(size, dtype=torch.get_default_dtype()) - (size - 1) / 2.0
    # Outer sum of squared offsets, equivalent to xx**2 + yy**2 on an 'ij' meshgrid.
    sq_dist = ax[:, None] ** 2 + ax[None, :] ** 2
    return torch.exp(-sq_dist / (2.0 * sigma ** 2))

def create_gaussian_mask(size=64, sigma=16):
    """
    Build a square Gaussian mask with defaults suited to 64x64 nodule patches.

    Parameters:
    - size (int): Height and width of the mask (default 64).
    - sigma (float): Standard deviation of the Gaussian (default 16).

    Returns:
    - torch.Tensor: the ``size x size`` mask produced by ``gaussian_mask``.
    """
    return gaussian_mask(size, sigma)


import matplotlib.pyplot as plt

def _insert_synthetic_nodules(clean, loc, nodule_generator, gauss_mask, device):
    """Add one generated nodule to every image of ``clean``, in place.

    Meant to run under ``torch.no_grad()``. ``loc[i, 0, :3]`` supplies the x, y and
    size used for image ``i``. ``gauss_mask`` is multiplied element-wise into each
    64x64 generated patch before it is resized.
    """
    latent = torch.randn(clean.shape[0], 64, 1, 1).to(device)
    nodules = nodule_generator(latent).reshape(-1, 64, 64).to(device)
    for i in range(nodules.shape[0]):
        resized = interpolate_nodule(torch.mul(gauss_mask, nodules[i]), loc[i, 0, 2])
        clean[i] = add_nodule_to_image(clean[i], resized, loc[i, 0, 0], loc[i, 0, 1])


def _show_samples(clean, fake_output, loc):
    """Plot the first image of the batch with a box at its sampled nodule location,
    then the composited and refined images side by side."""
    # .copy(): on CPU, .numpy() shares memory with ``clean``; without the copy,
    # cv2 would draw the box into the batch tensor itself.
    boxed = clean[0].cpu().detach().squeeze().numpy().copy()
    # Box drawn from the sampled location; it does not reflect any border
    # adjustment made when the nodule was pasted.
    x = int(loc[0, 0, 0] / 8)
    y = int(loc[0, 0, 1] / 8)
    s = int(loc[0, 0, 2] / 2)
    cv2.rectangle(boxed, (x - s, y - s), (x + s, y + s), (0, 255, 0), 3)
    plt.figure()
    plt.imshow(boxed)

    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.title('Clean Image')
    plt.imshow(clean[0].cpu().detach().squeeze().numpy(), cmap='gray')
    plt.subplot(1, 2, 2)
    plt.title('Refined Image')
    plt.imshow(fake_output[0].cpu().detach().squeeze().numpy(), cmap='gray')
    plt.show()


def train(refinement_network, discriminator, nodule_generator, dataloader, num_epochs, device):
    """Train the refinement network and the discriminator on images with synthetic nodules.

    Each batch is ``(clean, nodule, loc)``. Under ``no_grad`` a generated nodule is
    pasted into every ``clean`` image. The refinement network is then updated with
    the mean absolute difference between its output and that composited input.

    Every 25 global steps, the discriminator is updated with BCE: ``nodule`` images
    are labelled real (1) and refined images fake (0). After step 25, progress is
    printed and plotted whenever ``idx % 8 == 0``. Loss curves are plotted at the end.

    NOTE(review): ``nodule_generator`` is left in train mode, so its BatchNorm
    layers use batch statistics — confirm this is intended for a pretrained model.
    """
    refinement_network.to(device)
    discriminator.to(device)
    nodule_generator.to(device)

    optimizer_refinement = optim.Adam(refinement_network.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_discriminator = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    criterion = nn.BCELoss()

    refinement_losses = []
    discriminator_losses = []
    gauss_mask = (-1) * (create_gaussian_mask().to(device))
    # Defined up front so logging can never hit an unbound name before the first
    # discriminator update.
    discriminator_loss = None
    # 1-based count of batches across all epochs. A real counter instead of
    # 17*epoch+idx, which assumed exactly 17 batches per epoch.
    global_step = 0
    print("Oggy")
    for epoch in range(num_epochs):
        for idx, data in enumerate(tqdm(dataloader)):
            global_step += 1
            with torch.no_grad():
                clean, nodule, loc = data
                clean = clean.to(device)
                nodule = nodule.to(device)
                loc = loc.to(device)
                _insert_synthetic_nodules(clean, loc, nodule_generator, gauss_mask, device)

            refinement_network.zero_grad()
            fake_output = refinement_network(clean)
            refinement_loss = torch.mean(torch.abs(fake_output - clean))
            refinement_loss.backward()
            optimizer_refinement.step()
            refinement_losses.append(refinement_loss.item())

            if global_step % 25 == 0:
                discriminator.zero_grad()
                real_output = discriminator(nodule)
                # This loss only updates the discriminator; don't build a graph
                # through the refinement network.
                with torch.no_grad():
                    refined = refinement_network(clean)
                fake_output1 = discriminator(refined)
                real_loss = criterion(real_output, torch.ones_like(real_output))
                fake_loss = criterion(fake_output1, torch.zeros_like(fake_output1))
                discriminator_loss = (real_loss + fake_loss) / 2
                discriminator_loss.backward()
                optimizer_discriminator.step()
                discriminator_losses.append(discriminator_loss.item())

            if global_step > 25 and idx % 8 == 0:
                d_loss = discriminator_loss.item() if discriminator_loss is not None else "n/a"
                print(f"Epoch [{epoch + 1}/{num_epochs}], Iteration [{idx + 1}/{len(dataloader)}]")
                print(f"Refinement Loss: {refinement_loss.item()}, Discriminator Loss: {d_loss}")
                _show_samples(clean, fake_output, loc)

    # Plot the losses
    plt.figure(figsize=(10, 5))
    plt.plot(refinement_losses, label='Refinement Loss')
    plt.plot(discriminator_losses, label='Discriminator Loss')
    plt.xlabel('Iterations')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()





class CustomDataset(Dataset):
    """Yields ``(clean, nodule, loc)`` triples.

    Each item is drawn at random:
    - ``clean``: a random image from ``path_clean``.
    - ``nodule``: a random file from ``image_dir``.
    - ``loc``: one sample from ``gmm``.

    Args:
        path_nodules: Intended list of nodule image file names.
            NOTE(review): it is ignored — every file in ``image_dir`` is used as a
            "nodule" sample, which also sets ``len(self)``. This behaviour is kept
            as-is; confirm which list was intended.
        path_clean: File names, relative to ``image_dir``, to draw clean images from.
        gmm: Fitted model exposing ``sample(n)``. One sample is drawn per item.
        transform: Optional callable applied to both images.
        image_dir: Directory holding the image files. Forward slashes work on
            Windows and POSIX alike.
    """

    def __init__(self, path_nodules, path_clean, gmm, transform=None, image_dir="images/images"):
        self.transform = transform
        self.gmm = gmm
        self.image_dir = image_dir
        self.nodule_data = os.listdir(image_dir)
        self.clean_data = path_clean

    def __getitem__(self, index):
        # ``index`` is not used: files are picked at random, so DataLoader
        # shuffling does not influence which images are returned.
        clean = self._read_gray(random.choice(self.clean_data))
        nodule = self._read_gray(random.choice(self.nodule_data))
        samples, _ = self.gmm.sample(1)

        if self.transform:
            clean = self.transform(clean)
            nodule = self.transform(nodule)

        return clean, nodule, torch.from_numpy(samples)

    def _read_gray(self, name):
        """Read ``image_dir/name`` as grayscale, failing loudly if it can't be read."""
        path = os.path.join(self.image_dir, name)
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        return image

    def __len__(self):
        return len(self.nodule_data)

# Image preprocessing: tensor conversion, resize to 256x256, 3x3 Gaussian blur,
# then normalisation with mean 0.5 / std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((256, 256)),
        transforms.GaussianBlur(3),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

csv = pd.read_csv("jsrt_metadata.csv")  # metadata holding nodule positions and sizes
has_nodule = csv["state"] != "non-nodule"
nodule_rows = csv[has_nodule]

# One row per nodule: (x, y, size), used to fit the placement model below.
data = np.array([nodule_rows[column].tolist() for column in ("x", "y", "size")]).T
num_components = 3  # Number of Gaussian components

# Gaussian mixture over nodule (x, y, size), sampled later to place synthetic nodules.
gmm = GaussianMixture(n_components=num_components, covariance_type='full')
gmm.fit(data)

dataset = CustomDataset(
    nodule_rows["study_id"].tolist(),
    csv[~has_nodule]["study_id"].tolist(),
    gmm,
    transform=transform,
)
dataloader = DataLoader(dataset, batch_size=15, shuffle=True)




# Training

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("run")
refinement_network = RefinementNetwork()
discriminator = Discriminator()
# Generator with a 64-dimensional latent input, restored from a saved state_dict.
# NOTE: torch.load unpickles the file by default; only load trusted checkpoints.
nodulegen = NoduleGenerator(64)
generator_state = torch.load("NoduleGenerator.pt", map_location="cpu")
nodulegen.load_state_dict(generator_state)
print("Training")
train(refinement_network, discriminator, nodulegen, dataloader, num_epochs=10, device=device)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision
import os
import matplotlib.pyplot as plt
import numpy as np
import cv2
from tqdm import tqdm
import torch.nn.functional as F
import pandas as pd
from sklearn.mixture import GaussianMixture
import random

def interpolate_nodule(nodule_tensor, S):
    """Resize a nodule patch to ``S x S`` pixels with bilinear interpolation.

    ``S`` is read with ``.item()``, truncated to an int and made non-negative; a
    size of 0 is replaced by 2 so the result is never empty.

    Returns the resized patch with two leading singleton dimensions added, i.e.
    ``(1, 1, S, S)`` for a 2-D input patch.
    """
    side = abs(int(S.item()))
    if not side:
        side = 2
    # Add batch and channel axes: F.interpolate's bilinear mode needs a 4-D input.
    batched = nodule_tensor[None, None, ...]
    return F.interpolate(batched, size=(side, side), mode='bilinear', align_corners=False)

def add_nodule_to_image(larger_image, nodule, x, y, scale=8):
    """Add ``nodule`` onto ``larger_image`` centred near ``(x / scale, y / scale)``.

    ``x`` and ``y`` are divided by ``scale`` and truncated to ints. The patch values
    are *added* to the image region they cover.

    Placement near a border:
    - First, the legacy heuristic nudges the centre. It moves 20 px back when the
      patch crosses the right/bottom edge, and sets the coordinate to 20 when it
      crosses the left/top edge.
    - Then the centre is clamped so the whole patch lies inside the image.
    - A patch larger than the image is centre-cropped to the image size first.

    Args:
        larger_image: Image tensor. A singleton axis is inserted at position 1,
            e.g. ``(1, H, W)`` becomes ``(1, 1, H, W)``. Presumably a
            ``(C, H, W)`` image — confirm with callers.
        nodule: Patch whose last two dims are its height and width, e.g. the
            ``(1, 1, h, w)`` output of ``interpolate_nodule``.
        x, y: Centre coordinates before scaling (0-d tensors or numbers).
        scale: Divisor mapping ``x``/``y`` into ``larger_image`` pixel coordinates.

    Returns:
        ``larger_image.unsqueeze(1)`` with the patch added. This is a view, so the
        caller's ``larger_image`` is modified in place too.
    """
    x = int(x / scale)
    y = int(y / scale)
    canvas = larger_image.unsqueeze(1)
    nodule_h, nodule_w = nodule.shape[-2:]
    canvas_h, canvas_w = canvas.shape[-2:]

    # A patch bigger than the canvas can never fit: keep its central part.
    if nodule_h > canvas_h or nodule_w > canvas_w:
        top = max((nodule_h - canvas_h) // 2, 0)
        left = max((nodule_w - canvas_w) // 2, 0)
        nodule = nodule[..., top:top + min(nodule_h, canvas_h), left:left + min(nodule_w, canvas_w)]
        nodule_h, nodule_w = nodule.shape[-2:]

    # Legacy heuristic: nudge the centre away from the border.
    if x + nodule_w / 2 > canvas_w:
        x = x - 20
    if y + nodule_h / 2 > canvas_h:
        y = y - 20
    if x - nodule_w / 2 < 0:
        x = 20
    if y - nodule_h / 2 < 0:
        y = 20

    # The heuristic is not enough for large patches (e.g. width > 40 at the left
    # edge gives a negative slice start). Clamp so start >= 0 and end <= size.
    # When the heuristic already fits the patch, this clamp is a no-op.
    x = min(max(x, nodule_w // 2), canvas_w - nodule_w + nodule_w // 2)
    y = min(max(y, nodule_h // 2), canvas_h - nodule_h + nodule_h // 2)

    # start = centre - size//2 and end = start + size covers both odd and even sizes.
    top = y - nodule_h // 2
    left = x - nodule_w // 2
    canvas[:, :, top:top + nodule_h, left:left + nodule_w] += nodule
    return canvas

class NoduleGenerator(nn.Module):
    """Upsample-and-convolve generator producing ``channels_img``-channel images.

    For an input of shape ``(N, z_dim, 1, 1)`` the spatial size grows
    1 -> 4 -> 8 -> 16 -> 32 -> 64, giving ``(N, channels_img, 64, 64)``.
    The output goes through ``tanh``.

    The layer order inside ``self.gen`` defines the ``state_dict`` keys
    (``gen.0.1.weight`` ... ``gen.5.weight``). It must not change, or saved
    checkpoints will no longer load.
    """

    def __init__(self, z_dim=10, channels_img=1, features_g=64):
        super().__init__()
        widths = [features_g * 16, features_g * 8, features_g * 4, features_g * 2]
        stages = [self._block1(z_dim, widths[0], 3, 1, 1)]
        for in_ch, out_ch in zip(widths, widths[1:]):
            stages.append(self._block(in_ch, out_ch, 3, 1, 1))
        stages += [
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(widths[-1], channels_img, 3, 1, 1, bias=False),
            nn.Tanh(),
        ]
        self.gen = nn.Sequential(*stages)

    @staticmethod
    def _upsample_conv(in_channels, out_channels, kernel_size, stride, padding, factor):
        """Bilinear upsample by ``factor``, then conv (no bias) -> BatchNorm -> ReLU."""
        return nn.Sequential(
            nn.Upsample(scale_factor=factor, mode='bilinear'),
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Upsampling stage that doubles height and width."""
        return self._upsample_conv(in_channels, out_channels, kernel_size, stride, padding, 2)

    def _block1(self, in_channels, out_channels, kernel_size, stride, padding):
        """First upsampling stage, scaling height and width by 4."""
        return self._upsample_conv(in_channels, out_channels, kernel_size, stride, padding, 4)

    def forward(self, x):
        return self.gen(x)

# Refinement Network
class RefinementNetwork(nn.Module):
    """Three-layer convolutional network with an input skip connection.

    Two 3x3 conv + ReLU layers produce 64 feature maps. These are concatenated with
    the 1-channel input, giving 65 channels, and a final 3x3 conv maps them back to
    one channel. Padding 1 with stride 1 preserves height and width. The final
    sigmoid keeps outputs in (0, 1).

    NOTE(review): this file's ``transform`` applies ``Normalize((0.5,), (0.5,))``,
    which presumably puts inputs in [-1, 1]. Confirm that a (0, 1) output range is
    intended for the L1 target.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(65, 1, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_image_with_nodule):
        features = input_image_with_nodule
        for conv in (self.conv1, self.conv2):
            features = self.relu(conv(features))
        # Skip connection: the last layer sees the learned maps plus the raw input.
        stacked = torch.cat((features, input_image_with_nodule), dim=1)
        return self.sigmoid(self.conv3(stacked))

# Discriminator (for GAN training)
class Discriminator(nn.Module):
    """Two 3x3 conv layers, a fully connected layer and a sigmoid score per image.

    Args:
        image_size: Height and width of the square, single-channel input images.
            Both convolutions use stride 1 and padding 1, so the flattened feature
            count is ``64 * image_size**2``; the FC layer is sized accordingly.
            Defaults to 256.
    """

    def __init__(self, image_size=256):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.fc = nn.Linear(image_size * image_size * 64, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_image):
        """Return a ``(N, 1)`` tensor of probabilities in (0, 1)."""
        x = self.relu(self.conv1(input_image))
        x = self.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # Flatten for FC layer
        x = self.sigmoid(self.fc(x))
        return x

# # Define the training process
# def train(refinement_network, discriminator,nodule_generator, dataloader, num_epochs, device):
#     refinement_network.to(device)
#     discriminator.to(device)
#     nodule_generator.to(device)

#     optimizer_refinement = optim.Adam(refinement_network.parameters(), lr=0.0002, betas=(0.5, 0.999))
#     optimizer_discriminator = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
#     criterion = nn.BCELoss()

#     for epoch in range(num_epochs):
#         for idx, data in enumerate(dataloader):
            
#             with torch.no_grad():
#                 clean,nodule,loc=data
#                 clean.to(device)
#                 nodule.to(device)
#                 loc.to(device)
#                 nodules = nodule_generator(torch.randn(data.shape[0], 64)).reshape(-1,64,64).to(device)
#                 for i in range(nodules.shape[0]):
#                     interpolated_nodule = interpolate_nodule(nodules[i], loc[i,2])
#                     clean[i]=add_nodule_to_image(clean[i],interpolated_nodule,loc[i,0],loc[i,1])
                
#             refinement_network.zero_grad()
#             fake_output = refinement_network(clean)
#             refinement_loss = torch.mean(torch.abs(fake_output - clean))
#             refinement_loss.backward()
#             optimizer_refinement.step()

#             if(idx%4==0):
#                 discriminator.zero_grad()
#                 real_output = discriminator(nodule)
#                 fake_output = discriminator(refinement_network(clean))
#                 real_loss = criterion(real_output, torch.ones_like(real_output))
#                 fake_loss = criterion(fake_output, torch.zeros_like(fake_output))
#                 discriminator_loss = (real_loss + fake_loss) / 2
#                 discriminator_loss.backward()
#                 optimizer_discriminator.step()
import torchvision.transforms.functional as TF

def gaussian_mask(size, sigma):
    """
    Generates a 2D Gaussian mask.

    The grid is centred on the middle of the mask, so the result is symmetric for
    both odd and even ``size``. For odd sizes the centre pixel has value 1.

    Parameters:
    - size (int): Size of the mask (both height and width).
    - sigma (float): Standard deviation of the Gaussian distribution; must be positive.

    Returns:
    - torch.Tensor: ``size x size`` unnormalised Gaussian mask with values in (0, 1].

    Raises:
    - ValueError: if ``sigma`` is not positive.
    """
    if sigma <= 0:
        raise ValueError(f"sigma must be positive, got {sigma}")
    # Offset of each pixel from the centre. For even sizes the centre lies between
    # the two middle pixels, so the offsets are half-integers.
    ax = torch.arange(size, dtype=torch.get_default_dtype()) - (size - 1) / 2.0
    # Outer sum of squared offsets, equivalent to xx**2 + yy**2 on an 'ij' meshgrid.
    sq_dist = ax[:, None] ** 2 + ax[None, :] ** 2
    return torch.exp(-sq_dist / (2.0 * sigma ** 2))

def create_gaussian_mask(size=64, sigma=16):
    """
    Build a square Gaussian mask with defaults suited to 64x64 nodule patches.

    Parameters:
    - size (int): Height and width of the mask (default 64).
    - sigma (float): Standard deviation of the Gaussian (default 16).

    Returns:
    - torch.Tensor: the ``size x size`` mask produced by ``gaussian_mask``.
    """
    return gaussian_mask(size, sigma)


import matplotlib.pyplot as plt

def _insert_synthetic_nodules(clean, loc, nodule_generator, gauss_mask, device):
    """Add one generated nodule to every image of ``clean``, in place.

    Meant to run under ``torch.no_grad()``. ``loc[i, 0, :3]`` supplies the x, y and
    size used for image ``i``. ``gauss_mask`` is multiplied element-wise into each
    64x64 generated patch before it is resized.
    """
    latent = torch.randn(clean.shape[0], 64, 1, 1).to(device)
    nodules = nodule_generator(latent).reshape(-1, 64, 64).to(device)
    for i in range(nodules.shape[0]):
        resized = interpolate_nodule(torch.mul(gauss_mask, nodules[i]), loc[i, 0, 2])
        clean[i] = add_nodule_to_image(clean[i], resized, loc[i, 0, 0], loc[i, 0, 1])


def _show_samples(clean, fake_output, loc):
    """Plot the first image of the batch with a box at its sampled nodule location,
    then the composited and refined images side by side."""
    # .copy(): on CPU, .numpy() shares memory with ``clean``; without the copy,
    # cv2 would draw the box into the batch tensor itself.
    boxed = clean[0].cpu().detach().squeeze().numpy().copy()
    # Box drawn from the sampled location; it does not reflect any border
    # adjustment made when the nodule was pasted.
    x = int(loc[0, 0, 0] / 8)
    y = int(loc[0, 0, 1] / 8)
    s = int(loc[0, 0, 2] / 2)
    cv2.rectangle(boxed, (x - s, y - s), (x + s, y + s), (0, 255, 0), 3)
    plt.figure()
    plt.imshow(boxed)

    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.title('Clean Image')
    plt.imshow(clean[0].cpu().detach().squeeze().numpy(), cmap='gray')
    plt.subplot(1, 2, 2)
    plt.title('Refined Image')
    plt.imshow(fake_output[0].cpu().detach().squeeze().numpy(), cmap='gray')
    plt.show()


def train(refinement_network, discriminator, nodule_generator, dataloader, num_epochs, device):
    """Train the refinement network and the discriminator on images with synthetic nodules.

    Each batch is ``(clean, nodule, loc)``. Under ``no_grad`` a generated nodule is
    pasted into every ``clean`` image. The refinement network is then updated with
    the mean absolute difference between its output and that composited input.

    On every 8th batch of an epoch (``idx % 8 == 0``), the discriminator is updated
    with BCE: ``nodule`` images are labelled real (1) and refined images fake (0).
    Progress is printed and plotted on the same batches. Loss curves are plotted at
    the end.

    NOTE(review): ``nodule_generator`` is left in train mode, so its BatchNorm
    layers use batch statistics — confirm this is intended for a pretrained model.
    """
    refinement_network.to(device)
    discriminator.to(device)
    nodule_generator.to(device)

    optimizer_refinement = optim.Adam(refinement_network.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_discriminator = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    criterion = nn.BCELoss()

    refinement_losses = []
    discriminator_losses = []
    gauss_mask = (-1) * (create_gaussian_mask().to(device))
    print("Oggy")
    for epoch in range(num_epochs):
        for idx, data in enumerate(tqdm(dataloader)):
            with torch.no_grad():
                clean, nodule, loc = data
                clean = clean.to(device)
                nodule = nodule.to(device)
                loc = loc.to(device)
                _insert_synthetic_nodules(clean, loc, nodule_generator, gauss_mask, device)

            refinement_network.zero_grad()
            fake_output = refinement_network(clean)
            refinement_loss = torch.mean(torch.abs(fake_output - clean))
            refinement_loss.backward()
            optimizer_refinement.step()
            refinement_losses.append(refinement_loss.item())

            if idx % 8 == 0:
                discriminator.zero_grad()
                real_output = discriminator(nodule)
                # This loss only updates the discriminator; don't build a graph
                # through the refinement network.
                with torch.no_grad():
                    refined = refinement_network(clean)
                fake_output1 = discriminator(refined)
                real_loss = criterion(real_output, torch.ones_like(real_output))
                fake_loss = criterion(fake_output1, torch.zeros_like(fake_output1))
                discriminator_loss = (real_loss + fake_loss) / 2
                discriminator_loss.backward()
                optimizer_discriminator.step()
                discriminator_losses.append(discriminator_loss.item())

                # Logged in the same branch, so discriminator_loss is always bound here.
                print(f"Epoch [{epoch + 1}/{num_epochs}], Iteration [{idx + 1}/{len(dataloader)}]")
                print(f"Refinement Loss: {refinement_loss.item()}, Discriminator Loss: {discriminator_loss.item()}")
                _show_samples(clean, fake_output, loc)

    # Plot the losses
    plt.figure(figsize=(10, 5))
    plt.plot(refinement_losses, label='Refinement Loss')
    plt.plot(discriminator_losses, label='Discriminator Loss')
    plt.xlabel('Iterations')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()





class CustomDataset(Dataset):
    """Yields ``(clean, nodule, loc)`` triples.

    Each item is drawn at random:
    - ``clean``: a random image from ``path_clean``.
    - ``nodule``: a random file from ``image_dir``.
    - ``loc``: one sample from ``gmm``.

    Args:
        path_nodules: Intended list of nodule image file names.
            NOTE(review): it is ignored — every file in ``image_dir`` is used as a
            "nodule" sample, which also sets ``len(self)``. This behaviour is kept
            as-is; confirm which list was intended.
        path_clean: File names, relative to ``image_dir``, to draw clean images from.
        gmm: Fitted model exposing ``sample(n)``. One sample is drawn per item.
        transform: Optional callable applied to both images.
        image_dir: Directory holding the image files. Forward slashes work on
            Windows and POSIX alike.
    """

    def __init__(self, path_nodules, path_clean, gmm, transform=None, image_dir="images/images"):
        self.transform = transform
        self.gmm = gmm
        self.image_dir = image_dir
        self.nodule_data = os.listdir(image_dir)
        self.clean_data = path_clean

    def __getitem__(self, index):
        # ``index`` is not used: files are picked at random, so DataLoader
        # shuffling does not influence which images are returned.
        clean = self._read_gray(random.choice(self.clean_data))
        nodule = self._read_gray(random.choice(self.nodule_data))
        samples, _ = self.gmm.sample(1)

        if self.transform:
            clean = self.transform(clean)
            nodule = self.transform(nodule)

        return clean, nodule, torch.from_numpy(samples)

    def _read_gray(self, name):
        """Read ``image_dir/name`` as grayscale, failing loudly if it can't be read."""
        path = os.path.join(self.image_dir, name)
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        return image

    def __len__(self):
        return len(self.nodule_data)

# Image preprocessing: tensor conversion, resize to 256x256, 3x3 Gaussian blur,
# then normalisation with mean 0.5 / std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((256, 256)),
        transforms.GaussianBlur(3),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

csv = pd.read_csv("jsrt_metadata.csv")
has_nodule = csv["state"] != "non-nodule"
nodule_rows = csv[has_nodule]

# One row per nodule: (x, y, size), used to fit the placement model below.
data = np.array([nodule_rows[column].tolist() for column in ("x", "y", "size")]).T
num_components = 3  # Number of Gaussian components

# Gaussian mixture over nodule (x, y, size), sampled later to place synthetic nodules.
gmm = GaussianMixture(n_components=num_components, covariance_type='full')
gmm.fit(data)

dataset = CustomDataset(
    nodule_rows["study_id"].tolist(),
    csv[~has_nodule]["study_id"].tolist(),
    gmm,
    transform=transform,
)
dataloader = DataLoader(dataset, batch_size=15, shuffle=True)




In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("run")
refinement_network = RefinementNetwork()
discriminator = Discriminator()
# Generator with a 64-dimensional latent input, restored from a saved state_dict.
# NOTE: torch.load unpickles the file by default; only load trusted checkpoints.
nodulegen = NoduleGenerator(64)
generator_state = torch.load("NoduleGenerator.pt", map_location="cpu")
nodulegen.load_state_dict(generator_state)
print("Training")
train(refinement_network, discriminator, nodulegen, dataloader, num_epochs=10, device=device)