In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import torch
import torch.nn as nn
import os
from sklearn.model_selection import train_test_split
import torchvision 
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn.init as init

In [2]:
# Pick the compute device once; later cells read the global `device`.
cuda_ok = torch.cuda.is_available()
if not cuda_ok:
    print("Using CPU")
else:
    torch.cuda.set_device(0)  # make the first GPU the current CUDA device
    print("Using GPU")
device = torch.device("cuda" if cuda_ok else "cpu")
class VGG16_MRI(nn.Module):
    def __init__(self, num_classes=2):
        super(VGG16_MRI, self).__init__()
        # Load a pre-trained VGG16 model with batch normalization
        model = torchvision.models.vgg16_bn(pretrained=True)
        
        # Change the first convolutional layer to accept single-channel (grayscale) input
        model.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        
        # Retain the feature extraction layers
        self.feature = model.features
        
        # Define the feature dimension based on output size for 240x240 input
        # VGG16 feature output will be (512, 7, 7) for 224x224, so we calculate for 240x240
        self.feat_dim = 512 * 7 * 7  # Update this if output size changes with input size
        
        # Adjust the number of classes for binary classification (0 or 1)
        self.num_classes = num_classes
        
        # Batch normalization layer
        self.bn = nn.BatchNorm1d(self.feat_dim)
        self.bn.bias.requires_grad_(False)  # no shift
        
        # Fully connected layer to map features to the number of classes
        self.fc_layer = nn.Linear(self.feat_dim, self.num_classes)
        
        self.model = model
            
    def forward(self, x):
        # Pass input through feature extraction layers
        feature = self.feature(x)
        feature = feature.view(feature.size(0), -1)  # Flatten the feature map
        feature = self.bn(feature)  # Apply batch normalization
        res = self.fc_layer(feature)  # Output class scores
        
        return feature, res

    def predict(self, x):
        # Pass input through feature extraction layers
        feature = self.feature(x)
        feature = feature.view(feature.size(0), -1)  # Flatten the feature map
        feature = self.bn(feature)  # Apply batch normalization
        res = self.fc_layer(feature)  # Output class scores

        return res

Using GPU


In [3]:
# Per-image labels; each row names an image file (`filename`) and its `label`.
labels_df = pd.read_csv("/kaggle/input/preprocessed-brats23/labels.csv")

# Directory containing the files referenced by `filename`.
data_dir = "/kaggle/input/preprocessed-brats23/Images"

# 70/15/15 split, stratified on `label` with a fixed seed for reproducibility:
# carve off a 30% hold-out pool, then halve it into validation and test.
train_df, temp_df = train_test_split(
    labels_df,
    test_size=0.3,
    stratify=labels_df['label'],
    random_state=42,
)
val_df, test_df = train_test_split(
    temp_df,
    test_size=0.5,
    stratify=temp_df['label'],
    random_state=42,
)


In [4]:
# Dataset class
class MRIDataset(Dataset):
    """Map-style dataset that yields ``(image, label)`` pairs from a labels frame.

    Args:
        df: DataFrame with a ``filename`` column (relative to ``data_dir``)
            and an int-convertible ``label`` column.
        data_dir: directory holding the image files.
        transform: optional callable applied to the grayscale PIL image.
    """

    def __init__(self, df, data_dir, transform=None):
        self.df = df
        self.data_dir = data_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        path = os.path.join(self.data_dir, row['filename'])
        # Force single-channel ('L') so the 1-channel models can consume it.
        with Image.open(path) as raw:
            image = raw.convert('L')
        label = int(row['label'])

        if self.transform:
            image = self.transform(image)

        return image, label

# Resize to the 240x240 resolution the models expect; ToTensor scales to [0, 1].
# No mean/std normalisation is applied.
transform = transforms.Compose([transforms.Resize((240, 240)), transforms.ToTensor()])

# One dataset per split, all sharing the same transform.
train_dataset, val_dataset, test_dataset = (
    MRIDataset(split_df, data_dir, transform=transform)
    for split_df in (train_df, val_df, test_df)
)

# Only the training loader shuffles; evaluation order stays fixed.
batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader, test_loader = (
    DataLoader(split_ds, batch_size=batch_size, shuffle=False)
    for split_ds in (val_dataset, test_dataset)
)


In [5]:
import torch.optim as optim
from tqdm import tqdm

# Initialize the model, loss function, and optimizer.
# `device` falls back to CPU, so this cell also runs on CPU-only kernels.
model = VGG16_MRI(num_classes=2).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
def train(num_epochs=10, save_path='/kaggle/working/classifier.pt'):
    """Fine-tune the notebook-level ``model`` and validate after each epoch.

    Uses the globals ``model``, ``criterion``, ``optimizer``, ``train_loader``,
    ``val_loader`` and ``device``. After the last epoch, ``model.state_dict()``
    is saved to ``save_path``.

    Args:
        num_epochs: number of passes over the training set (default 10).
        save_path: destination of the final ``state_dict``.
    """
    for epoch in range(num_epochs):
        # ---- training phase ----
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for images, labels in tqdm(train_loader):
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            _, outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # The loss is a batch mean; weight by batch size for a per-sample average.
            running_loss += loss.item() * images.size(0)
            predicted = outputs.detach().argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_loss = running_loss / len(train_loader.dataset)
        train_accuracy = 100 * correct / total
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}, Accuracy: {train_accuracy:.2f}%')

        # ---- validation phase ----
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device)
                labels = labels.to(device)

                _, outputs = model(images)
                loss = criterion(outputs, labels)

                val_loss += loss.item() * images.size(0)
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_loss = val_loss / len(val_loader.dataset)
        val_accuracy = 100 * correct / total
        print(f'Validation Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.2f}%')

    torch.save(model.state_dict(), save_path)

Downloading: "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth" to /root/.cache/torch/hub/checkpoints/vgg16_bn-6c64b313.pth
100%|██████████| 528M/528M [00:06<00:00, 81.3MB/s]


In [6]:
def load_pretrained_classifier(path=None, map_location="cpu"):
    """Build a ``VGG16_MRI`` and load fine-tuned weights into it.

    Args:
        path: checkpoint holding a ``state_dict`` (as written by ``train``).
            Defaults to the Kaggle model-input location.
        map_location: forwarded to ``torch.load``. The default ``"cpu"`` lets
            a GPU-saved checkpoint load on CPU-only machines. The freshly
            built model lives on CPU anyway, so the loaded weights are identical.

    Returns:
        The model, in eval mode and on CPU.
    """
    if path is None:
        path = "/kaggle/input/brats23-classifier/pytorch/default/1/classifier.pt"
    model = VGG16_MRI(num_classes=2)
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(path, map_location=map_location)
    model.load_state_dict(state_dict)
    model.eval()
    return model

In [7]:
def create_directory_if_not_exists(directory_path):
    """Create ``directory_path`` (including parents) if needed and report the outcome.

    ``OSError`` from ``os.makedirs`` is printed rather than raised, so callers
    continue even if creation failed.
    """
    try:
        os.makedirs(directory_path, exist_ok=True)
    except OSError as error:
        print(f"Error creating directory: {error}")
    else:
        print(f"Directory created successfully: {directory_path}")

In [8]:
# utils.py
def _toggle_grads(net, enabled):
    """Set ``requires_grad`` to ``enabled`` on every parameter of ``net``."""
    for param in net.parameters():
        param.requires_grad_(enabled)


def freeze(net):
    """Stop gradient tracking for all of ``net``'s parameters (in place)."""
    _toggle_grads(net, False)


def unfreeze(net):
    """Resume gradient tracking for all of ``net``'s parameters (in place)."""
    _toggle_grads(net, True)

In [9]:
import torch.nn as nn

class Generator(nn.Module):
    """DCGAN-style generator mapping a latent vector to a 1x240x240 image.

    A linear layer expands ``z`` to a ``(dim*8, 15, 15)`` map. Four stride-2
    transposed convolutions then upsample 15 -> 30 -> 60 -> 120 -> 240, and a
    final sigmoid keeps pixels in [0, 1].

    Args:
        in_dim: latent size (default 100).
        dim: base channel width (default 64).
    """

    def __init__(self, in_dim=100, dim=64):
        super().__init__()
        seed_side = 15  # spatial side of the first feature map

        def up_block(c_in, c_out):
            # k=5, s=2, padding=2, output_padding=1 exactly doubles H and W.
            return nn.Sequential(
                nn.ConvTranspose2d(c_in, c_out, 5, 2, padding=2, output_padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            )

        n_seed = dim * 8 * seed_side * seed_side
        self.l1 = nn.Sequential(
            nn.Linear(in_dim, n_seed, bias=False),
            nn.BatchNorm1d(n_seed),
            nn.ReLU(),
        )

        # Widths dim*8 -> dim*4 -> dim*2 -> dim, then a 1-channel output head.
        widths = [dim * 8, dim * 4, dim * 2, dim]
        self.l2_5 = nn.Sequential(
            *[up_block(c_in, c_out) for c_in, c_out in zip(widths, widths[1:])],
            nn.ConvTranspose2d(dim, 1, 5, 2, padding=2, output_padding=1),  # 120 -> 240
            nn.Sigmoid(),
        )

    def forward(self, x):
        hidden = self.l1(x)
        hidden = hidden.view(hidden.size(0), -1, 15, 15)
        return self.l2_5(hidden)



In [10]:
# Discriminator discri.py 
class MinibatchDiscrimination(nn.Module):
    """Minibatch-discrimination layer (Salimans et al., 2016).

    Each sample is projected through a learned tensor ``T`` to
    ``out_features`` rows of length ``kernel_dims``. The layer takes L1
    distances to every other sample's rows and appends ``sum_j exp(-dist)``
    per row as ``out_features`` extra features.

    Args:
        in_features: size A of each input vector (input is N x A).
        out_features: number B of appended similarity features.
        kernel_dims: length C of each projected row.
        mean: if True, average over the other N-1 samples instead of summing.
    """

    def __init__(self, in_features, out_features, kernel_dims, mean=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.kernel_dims = kernel_dims
        self.mean = mean
        self.T = nn.Parameter(torch.empty(in_features, out_features, kernel_dims))
        # `normal_` is the supported in-place initializer (`init.normal` is deprecated).
        init.normal_(self.T, 0, 1)

    def forward(self, x):
        # x is N x A, T is A x B x C
        matrices = x.mm(self.T.view(self.in_features, -1))
        matrices = matrices.view(-1, self.out_features, self.kernel_dims)  # N x B x C

        M = matrices.unsqueeze(0)          # 1 x N x B x C
        M_T = M.permute(1, 0, 2, 3)        # N x 1 x B x C
        norm = torch.abs(M - M_T).sum(3)   # N x N x B pairwise L1 distances
        expnorm = torch.exp(-norm)
        o_b = expnorm.sum(0) - 1           # N x B; drop the self term exp(0) = 1
        if self.mean:
            # With N == 1 there are no other samples and o_b is already 0;
            # dividing by max(N-1, 1) avoids producing 0/0 = NaN.
            o_b = o_b / max(x.size(0) - 1, 1)

        return torch.cat([x, o_b], 1)


class MinibatchDiscriminator(nn.Module):
    """Conv discriminator with a minibatch-discrimination head.

    Four stride-2 conv blocks reduce a ``(in_dim, 240, 240)`` input to a
    ``(dim*4, 15, 15)`` map. The map is flattened, extended with 64
    minibatch-similarity features and mapped to ``n_classes`` logits.

    Args:
        in_dim: input channels (default 1).
        dim: base channel width (default 64).
        n_classes: number of output logits.

    ``forward`` returns ``(flattened_features, logits)``.
    """

    def __init__(self, in_dim=1, dim=64, n_classes=1000):
        super().__init__()
        self.n_classes = n_classes

        def conv_ln_lrelu(in_dim, out_dim, k, s, p):
            return nn.Sequential(
                nn.Conv2d(in_dim, out_dim, k, s, p),
                # There is no efficient LayerNorm for conv maps, so
                # InstanceNorm2d is used instead.
                nn.InstanceNorm2d(out_dim, affine=True),
                nn.LeakyReLU(0.2))

        self.layer1 = conv_ln_lrelu(in_dim, dim, 5, 2, 2)        # 240 -> 120
        self.layer2 = conv_ln_lrelu(dim, dim * 2, 5, 2, 2)       # 120 -> 60
        self.layer3 = conv_ln_lrelu(dim * 2, dim * 4, 5, 2, 2)   # 60 -> 30
        self.layer4 = conv_ln_lrelu(dim * 4, dim * 4, 3, 2, 1)   # 30 -> 15

        # Flattened size for a 240x240 input, derived from `dim`. For dim=64
        # this is the former hard-coded 57600 (= 256 * 15 * 15).
        feat_dim = dim * 4 * 15 * 15
        self.mbd1 = MinibatchDiscrimination(feat_dim, 64, 50)
        self.fc_layer = nn.Linear(feat_dim + 64, self.n_classes)

    def forward(self, x):
        bs = x.shape[0]
        feat = self.layer4(self.layer3(self.layer2(self.layer1(x))))
        feat = feat.view(bs, -1)
        mb_out = self.mbd1(feat)   # N x (A + B)
        y = self.fc_layer(mb_out)
        return feat, y

In [11]:
def get_GAN(n_classes, z_dim, Pretrained = False):
    """Build the generator/discriminator pair, optionally loading checkpoints.

    Both nets are wrapped in ``torch.nn.DataParallel`` and moved to the
    notebook-level ``device``.

    Args:
        n_classes: discriminator output size.
        z_dim: generator latent size.
        Pretrained: if True, load ``ep75_improved_BraTS23_{G,D}.pt`` from the
            Kaggle input dataset.

    Returns:
        ``(G, D)``.
    """
    G = torch.nn.DataParallel(Generator(z_dim)).to(device)
    D = torch.nn.DataParallel(MinibatchDiscriminator(n_classes=n_classes)).to(device)

    if Pretrained:
        root_path = "/kaggle/input/brats23-gan-epoch75/pytorch/default/1/attack_results"
        dataset_name = "BraTS23"
        # This must be the same name used to build `path` below (it was
        # misspelled before, so Pretrained=True raised NameError).
        model_name_T = "VGG16_MRI"
        path = os.path.join(root_path, dataset_name, model_name_T)
        path_G = os.path.join(path, "ep75_improved_{}_G.pt".format(dataset_name))
        path_D = os.path.join(path, "ep75_improved_{}_D.pt".format(dataset_name))
        # map_location lets GPU-saved checkpoints load on a CPU-only kernel.
        ckp_G = torch.load(path_G, map_location=device)
        G.load_state_dict(ckp_G['state_dict'], strict=True)
        ckp_D = torch.load(path_D, map_location=device)
        D.load_state_dict(ckp_D['state_dict'], strict=True)
        print("Loaded Pretrained Model (Specific GAN)")

    return G, D

In [12]:
def get_augmodel():
    """Load the fine-tuned VGG16_MRI classifier, wrapped in DataParallel, on ``device``.

    Uses ``.to(device)`` instead of ``.cuda()`` so it also works on CPU-only
    kernels. On GPU runs both resolve to the current CUDA device.
    """
    model = load_pretrained_classifier()
    return torch.nn.DataParallel(model).to(device)

In [13]:
import time

def init_dataloader(df = None, data_dir="/kaggle/input/preprocessed-brats23/Images", batch_size=64, mode="gan", transform=None, iterator=False):
    """Build an ``MRIDataset`` and a ``DataLoader`` over it.

    Args:
        df: labels frame with ``filename``/``label`` columns. If None, the
            labels CSV is read and its stratified 60% split (seed 42) is used.
        data_dir: directory holding the image files.
        batch_size: loader batch size; the incomplete last batch is dropped.
        mode: ``"attack"`` disables shuffling; any other value shuffles.
        transform: image transform. Defaults to Resize(240x240) + ToTensor.
        iterator: if True, return an iterator over a single-process
            (``num_workers=0``) loader instead of the loader itself.

    Returns:
        ``(dataset, data_loader)``.
    """
    start = time.time()
    if df is None:
        full_df = pd.read_csv("/kaggle/input/preprocessed-brats23/labels.csv")
        # Split the frame just read, not the notebook-global `labels_df`.
        df, _ = train_test_split(full_df, test_size=0.4, stratify=full_df['label'], random_state=42)

    shuffle_flag = mode != "attack"

    # Honour a caller-supplied transform; otherwise use the default.
    if transform is None:
        transform = transforms.Compose([
            transforms.Resize((240, 240)),  # Ensure image size is 240x240
            transforms.ToTensor(),
        ])

    dataset = MRIDataset(df=df, data_dir=data_dir, transform=transform)

    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=shuffle_flag,
                             drop_last=True,
                             num_workers=0 if iterator else 2,
                             pin_memory=True)
    if iterator:
        data_loader = iter(data_loader)

    interval = time.time() - start
    print(f'Initializing data loader took {interval:.2f} seconds')

    return dataset, data_loader


In [14]:
def get_act_reg(train_loader,T,device,Nsample=5000):
    """Per-dimension mean and std of ``T``'s features over roughly ``Nsample`` inputs.

    Batches are consumed until ``batch_idx * batch_len > Nsample``, so slightly
    more than ``Nsample`` samples may be used.

    Args:
        train_loader: iterable of ``(inputs, labels)`` batches.
        T: model whose call returns ``(features, logits)``.
        device: device the inputs are moved to.
        Nsample: approximate sample budget.

    Returns:
        ``(fea_mean, fea_logvar)``. NOTE: despite its name, ``fea_logvar`` is
        the standard deviation (``torch.std``), not a log-variance.

    Raises:
        ValueError: if the loader yields no batch.
    """
    feats = []
    with torch.no_grad():
        for batch_idx, (data, _) in enumerate(train_loader):
            if batch_idx * len(data) > Nsample:
                break
            fea, _ = T(data.to(device))
            feats.append(fea)
    if not feats:
        raise ValueError("get_act_reg: train_loader yielded no batches")

    # Concatenate once rather than re-allocating with torch.cat every batch.
    all_fea = torch.cat(feats)
    fea_mean = torch.mean(all_fea, dim=0)
    fea_logvar = torch.std(all_fea, dim=0)

    print(fea_mean.shape, fea_logvar.shape, all_fea.shape)
    return fea_mean, fea_logvar

In [15]:
#get_attack_model (utils.py)
def get_attack_model(eval_mode=False, pretrained_gan=False):
    """Assemble the models and feature statistics used by the inversion attack.

    Args:
        eval_mode: unused; kept for interface compatibility.
        pretrained_gan: forwarded to ``get_GAN(Pretrained=...)``. The default
            False keeps the previous behaviour: G and D are freshly initialised.

    Returns:
        ``(targetnets, E, G, D, n_classes, fea_mean, fea_logvar)``:

        - ``targetnets`` is a one-element list with the target classifier.
        - ``E`` is a second, independently loaded copy of the same classifier,
          used for evaluation.
        - ``fea_mean`` / ``fea_logvar`` are one-element lists holding the
          target's feature statistics from ``get_act_reg``.

    NOTE(review): ``recovery()`` calls this with defaults, so the attack runs
    against an untrained GAN unless ``pretrained_gan=True``. Confirm intent.
    """
    n_classes = 2

    G, D = get_GAN(n_classes=n_classes, z_dim=100, Pretrained=pretrained_gan)

    # Target classifier.
    target = get_augmodel().to(device).eval()
    targetnets = [target]

    # Feature statistics used by the `logit_loss` regulariser in `iden_loss`.
    data_dir = "/kaggle/input/preprocessed-brats23/Images"
    _, dataloader_gan = init_dataloader(df=None, data_dir=data_dir)
    fea_mean_, fea_logvar_ = get_act_reg(dataloader_gan, target, device)
    fea_mean = [fea_mean_.to(device)]
    fea_logvar = [fea_logvar_.to(device)]

    # Evaluation classifier (loaded from the same checkpoint as the target).
    E = get_augmodel()
    E.eval()
    G.eval()
    D.eval()

    return targetnets, E, G, D, n_classes, fea_mean, fea_logvar


In [16]:
# attack.py
def reparameterize(mu, logvar):
    """
    Reparameterisation trick: return ``mu + sigma * eps`` with ``eps ~ N(0, 1)``.
    :param mu: (Tensor) Mean of the latent Gaussian [B x D]
    :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D];
        the standard deviation used is ``exp(0.5 * logvar)``
    :return: (Tensor) [B x D]
    """
    sigma = (0.5 * logvar).exp()
    noise = torch.randn_like(sigma)
    return mu + noise * sigma

def reg_loss(featureT,fea_mean, fea_logvar):
    """Mean squared distance between ``featureT`` and the stored feature mean.

    A sample from ``reparameterize(fea_mean, fea_logvar)`` is drawn and then
    discarded; the loss compares against ``fea_mean`` itself. The draw still
    advances the global RNG, so it is kept to leave later sampling unchanged.

    Args:
        featureT: (B x D) features from the target net.
        fea_mean: feature mean; presumably (D,) as produced by get_act_reg.
        fea_logvar: only passed to ``reparameterize``.

    Returns:
        Scalar tensor ``mean((featureT - fea_mean) ** 2)``.
    """
    _unused_sample = reparameterize(fea_mean, fea_logvar)
    target = fea_mean.repeat(featureT.shape[0], 1)
    return (featureT - target).pow(2).mean()

def iden_loss(T,fake, iden, used_loss,criterion,fea_mean=0, fea_logvar=0,lam=0.1):
    """Identity loss of ``fake`` w.r.t. target labels ``iden``, averaged over ``T``.

    For ``used_loss == 'logit_loss'`` a feature regulariser
    ``lam * reg_loss(...)`` is added. It is computed on the FIRST net in ``T``
    only and is not averaged. With ``find_criterion('logit_loss')`` the
    criterion is ``NLLLoss`` applied to raw logits, i.e. the negative mean
    logit of the target class.

    Args:
        T: sized iterable of nets returning ``(features, logits)``.
        fake: generated batch.
        iden: target class indices.
        used_loss: loss name; only ``'logit_loss'`` adds the regulariser.
        criterion: callable ``criterion(logits, iden)``.
        fea_mean, fea_logvar: indexable per-net feature stats; only element 0
            is used, and only for ``'logit_loss'``.
        lam: regulariser weight.

    Returns:
        ``sum_k criterion(out_k, iden) / len(T) + loss_reg``.
    """
    Iden_Loss = 0
    loss_reg = 0
    for k, tn in enumerate(T):
        feat, out = tn(fake)
        Iden_Loss = Iden_Loss + criterion(out, iden)
        # Regularise with the first (target) net only. Checking the index
        # (instead of `Iden_Loss == 0`) cannot misfire when that net's loss is 0.
        if used_loss == 'logit_loss' and k == 0:
            loss_reg = lam * reg_loss(feat, fea_mean[0], fea_logvar[0])

    return Iden_Loss / len(T) + loss_reg

In [17]:
def find_criterion(used_loss):
    """Return the loss module for ``used_loss`` on ``device``, or None.

    ``'logit_loss'`` maps to ``nn.NLLLoss``, which callers apply to raw logits.
    ``'cel'`` maps to ``nn.CrossEntropyLoss``. Any other name yields None.
    """
    if used_loss == 'cel':
        criterion = nn.CrossEntropyLoss().to(device)
        print('criterion', criterion)
        return criterion
    criterion = nn.NLLLoss().to(device) if used_loss == 'logit_loss' else None
    print('criterion:{}'.format(used_loss))
    return criterion

In [18]:
def get_deprocessor():
    """Transform pipeline: resize a PIL image to 240x240, then convert it to a tensor."""
    return transforms.Compose([
        transforms.Resize((240, 240)),
        transforms.ToTensor(),
    ])


def low2high(img):
    """Round-trip a batch of 1-channel images through PIL at 240x240.

    Each image is converted to an 8-bit grayscale PIL image, resized to
    240x240 and turned back into a float tensor in [0, 1]. For 240x240
    inputs (the generator's output size) the resize keeps the size, so the
    net effect is 8-bit quantisation.

    Args:
        img: (B, 1, H, W) float tensor. Assumes values in [0, 1], e.g. the
            generator's sigmoid output; TODO confirm for other callers.

    Returns:
        (B, 1, 240, 240) float tensor on the same device as ``img``.
    """
    bs = img.size(0)
    proc = get_deprocessor()
    to_pil = transforms.ToPILImage()
    img_cpu = img.detach().cpu().float()

    out = torch.zeros(bs, 1, 240, 240)
    for i in range(bs):
        pil_i = to_pil(img_cpu[i]).convert('L')  # 'L' = 8-bit grayscale
        out[i] = proc(pil_i)

    # Return on the input's device rather than forcing `.cuda()`, so this
    # also works on CPU-only kernels.
    return out.to(img.device)

In [19]:

def dist_inversion(G, D, T, E, iden, lr=2e-2, momentum=0.9, lamda=100, \
                   iter_times=1500, clip_range=1.0, improved=False, num_seeds=5, \
                   used_loss='cel', prefix='', random_seed=0, save_img_dir='',fea_mean=0, \
                   fea_logvar=0, lam=0.1, clipz=False):
    """KEDMI-style distributional model inversion for a batch of target labels.

    Optimises a per-sample Gaussian latent with Adam. ``mu`` (init 0) and
    ``log_var`` (init 1) are both ``bs x 100``. The goal is for ``G(z)`` to be
    classified as ``iden`` by the nets in ``T`` while scoring well under the
    discriminator ``D``. Every 300 iterations the latent parameters and labels
    are saved under ``prefix``, a sample grid is written to
    ``save_img_dir + '<iter>.png'``, and accuracy under ``E`` is printed.

    Args:
        G, D, E: generator, discriminator and evaluation classifier (set to
            eval mode here). ``E(x)[-1]`` is taken as the logits.
        T: list of target nets returning ``(features, logits)``.
        iden: target labels; flattened to a long tensor on ``device``.
        lr: Adam learning rate.
        momentum, num_seeds: unused.
        lamda: weight of the identity loss relative to the prior loss.
        iter_times: number of optimisation steps.
        clip_range, clipz: if ``clipz`` is True, ``z`` is clamped to
            ``[-clip_range, clip_range]``.
        improved: True means ``D`` returns ``(feat, logits)`` and the prior
            loss is ``mean(softplus(lse)) - mean(lse)``. False means the
            prior is ``-D(fake).mean()``.
        used_loss: criterion name passed to ``find_criterion``.
        prefix: path prefix for the ``.npy`` latent/label files.
        random_seed: only used in output file names.
        save_img_dir: prefix for saved image files.
        fea_mean, fea_logvar, lam: forwarded to ``iden_loss``.

    Returns None. Results are only written to disk and printed.

    NOTE(review): with ``improved=False`` this calls ``label.mean()`` on the
    raw ``D(fake)`` output. ``MinibatchDiscriminator`` returns a tuple, so that
    branch would fail with it. Confirm before using improved=False.
    NOTE(review): the whole optimisation is skipped if
    ``{prefix}_iter_{random_seed}_{iter_times-1}_dis.npy`` already exists.
    That file is only written by the periodic save when ``iter_times`` is a
    multiple of 300.
    """
    
    iden = iden.view(-1).long().to(device)
    criterion = find_criterion(used_loss)
    bs = iden.shape[0]
    
    G.eval() 
    D.eval()
    E.eval()
    
    #NOTE
    # `Variable` is a legacy wrapper; here it just yields a leaf tensor with
    # requires_grad=True. mu/log_var are created on CPU. Presumably the
    # DataParallel-wrapped G (as built by get_GAN) moves z to its device —
    # confirm if G is not DataParallel.
    mu = Variable(torch.zeros(bs, 100), requires_grad=True)
    log_var = Variable(torch.ones(bs, 100), requires_grad=True)
    
    params = [mu, log_var]
    solver = optim.Adam(params, lr=lr)
    # Resume guard: np.save appends '.npy', which is the name checked here.
    outputs_z = "{}_iter_{}_{}_dis.npy".format(prefix, random_seed, iter_times-1)
    
    if not os.path.exists(outputs_z):
        # Save the initial latent distribution and labels. The dict is stored
        # as a 0-d object array (load with allow_pickle=True).
        outputs_z = "{}_iter_{}_{}_dis".format(prefix, random_seed, 0)
        outputs_label = "{}_iter_{}_{}_label".format(prefix, random_seed, 0)
        np.save(outputs_z,{"mu":mu.detach().cpu().numpy(),"log_var":log_var.detach().cpu().numpy()})
        np.save(outputs_label,iden.detach().cpu().numpy())
            
        for i in range(iter_times):
            z = reparameterize(mu, log_var)
            if clipz==True:
                z =  torch.clamp(z,-clip_range,clip_range).float()
            fake = G(z)

            if improved == True:
                _, label =  D(fake)
            else:
                label = D(fake)
                    
            # Manually clear gradients on the latent parameters.
            for p in params:
                if p.grad is not None:
                    p.grad.data.zero_()
            Iden_Loss = iden_loss(T,fake, iden, used_loss, criterion, fea_mean, fea_logvar, lam)

            if improved:
                # mean(softplus(l) - l) == mean(-log(sigmoid(l))) with l = logsumexp of D's
                # logits: rewards D scoring the fakes as real.
                Prior_Loss = torch.mean(F.softplus(log_sum_exp(label))) - torch.mean(log_sum_exp(label))
            else:
                Prior_Loss = - label.mean()

            Total_Loss = Prior_Loss + lamda * Iden_Loss
           
            Total_Loss.backward()
            solver.step()

            Prior_Loss_val = Prior_Loss.item()
            Iden_Loss_val = Iden_Loss.item()

            # Periodic checkpoint + evaluation every 300 iterations (file index = i).
            if (i+1) % 300 == 0:
                outputs_z = "{}_iter_{}_{}_dis".format(prefix, random_seed, i)
                outputs_label = "{}_iter_{}_{}_label".format(prefix, random_seed, i)
                np.save(outputs_z,{"mu":mu.detach().cpu().numpy(),"log_var":log_var.detach().cpu().numpy()})
                np.save(outputs_label,iden.detach().cpu().numpy())
        
                with torch.no_grad():
                    z = reparameterize(mu, log_var)
                    if clipz==True:
                        z =  torch.clamp(z,-clip_range, clip_range).float()
                    fake_img = G(z.detach())
                    # E returns (features, logits); [-1] selects the logits.
                    eval_prob = E(low2high(fake_img))[-1]
                    
                    eval_iden = torch.argmax(eval_prob, dim=1).view(-1)
                    acc = iden.eq(eval_iden.long()).sum().item() * 100.0 / bs
                    save_tensor_images(fake_img, save_img_dir + '{}.png'.format(i+1))
                    print("Iteration:{}\tPrior Loss:{:.2f}\tIden Loss:{:.2f}\tAttack Acc:{:.2f}".format(i+1, Prior_Loss_val, Iden_Loss_val, acc))
                    
                        
        # Final save, indexed by iter_times (one past the last loop index).
        outputs_z = "{}_iter_{}_{}_dis".format(prefix, random_seed, iter_times)
        outputs_label = "{}_iter_{}_{}_label".format(prefix, random_seed, iter_times)
        np.save(outputs_z,{"mu":mu.detach().cpu().numpy(),"log_var":log_var.detach().cpu().numpy()})
        np.save(outputs_label,iden.detach().cpu().numpy())

In [20]:
import torchvision.utils as tvls
def save_tensor_images(images, filename, nrow = None, normalize = True):
    """Write ``images`` to ``filename`` as one grid image without padding.

    ``nrow`` is forwarded to ``torchvision.utils.save_image`` only when it is
    truthy; otherwise that function's default grid width applies.
    """
    options = {"normalize": normalize, "padding": 0}
    if nrow:
        options["nrow"] = nrow
    tvls.save_image(images, filename, **options)

class HLoss(nn.Module):
    """Entropy of ``softmax(x)`` along dim 1, summed over the whole batch (not averaged)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        p_log_p = F.softmax(x, dim=1) * F.log_softmax(x, dim=1)
        return -p_log_p.sum()

# define "soft" cross-entropy with pytorch tensor operations
def softXEnt (input, target):
    """Soft-label cross-entropy between two (N, C) logit tensors.

    Returns ``-sum(softmax(target) * log_softmax(input)) / N``, i.e. the
    per-sample cross-entropy averaged over the batch.
    """
    log_q = nn.functional.log_softmax(input, dim=1)
    p = nn.functional.softmax(target, dim=1)
    return -(p * log_q).sum() / input.shape[0]

def log_sum_exp(x, axis = 1):
    """Numerically stable ``log(sum(exp(x), dim=axis))``.

    ``axis`` is now honoured; previously the max and unsqueeze were hard-coded
    to dim 1, so any other axis gave wrong results or a shape error.
    ``torch.logsumexp`` also handles rows that are all ``-inf``.
    """
    return torch.logsumexp(x, dim=axis)

def train_specific_gan(epochs=25, batch_size=64, lr=0.0002, z_dim=100, n_critic=5):
    """Train the target-specific semi-supervised ("improved") GAN.

    Discriminator step (every batch): soft cross-entropy between ``DG``'s
    logits on real images and the target classifier's logits, plus an
    unsupervised real/fake term on unlabeled and generated images.

    Generator step (every ``n_critic`` batches): feature matching between
    the mean ``DG`` features of generated and unlabeled images, plus a small
    entropy term on ``DG``'s logits for the fakes.

    After every epoch both state dicts are saved as
    ``improved_BraTS23_{G,D}.pt`` under
    ``/kaggle/working/attack_results/BraTS23/VGG16_MRI``. Every 10 epochs a
    32-sample grid is written to its ``imgs/`` folder.

    Args:
        epochs: passes over the GAN dataloader (default 25).
        batch_size: batch size of all loaders (default 64).
        lr: Adam learning rate for both G and DG (default 0.0002).
        z_dim: latent size (default 100).
        n_critic: discriminator steps per generator step (default 5).
    """
    model_name_T = "VGG16_MRI"
    dataset_name = "BraTS23"

    # Create save folders.
    root_path = "/kaggle/working/attack_results"
    save_model_dir = os.path.join(root_path, dataset_name, model_name_T)
    save_img_dir = os.path.join(save_model_dir, "imgs")
    os.makedirs(save_model_dir, exist_ok=True)
    os.makedirs(save_img_dir, exist_ok=True)

    # Target classifier whose predictions supervise DG.
    T = get_augmodel()

    data_dir = "/kaggle/input/preprocessed-brats23/Images"
    _, dataloader = init_dataloader(df=None, data_dir=data_dir, batch_size=batch_size)

    print("Training GAN for %s" % model_name_T)

    G = torch.nn.DataParallel(Generator(z_dim)).to(device)
    DG = torch.nn.DataParallel(MinibatchDiscriminator(n_classes=2)).to(device)

    dg_optimizer = torch.optim.Adam(DG.parameters(), lr=lr, betas=(0.5, 0.999))
    g_optimizer = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))

    entropy = HLoss()

    step = 0
    # Defined up front so the epoch summary works even if no G step (or no
    # batch) happened yet.
    g_loss = float("nan")
    acc = float("nan")
    for epoch in range(epochs):
        start = time.time()
        # Two fresh iterators over the unlabeled data for this epoch. They have
        # the same length as `dataloader` (same dataset, drop_last=True).
        _, unlabel_loader1 = init_dataloader(df=None, data_dir=data_dir, batch_size=batch_size, mode="gan", iterator=True)
        _, unlabel_loader2 = init_dataloader(df=None, data_dir=data_dir, batch_size=batch_size, mode="gan", iterator=True)
        for imgs, _ in dataloader:
            step += 1
            imgs = imgs.to(device)
            bs = imgs.size(0)
            x_unlabel = next(unlabel_loader1)[0]
            x_unlabel2 = next(unlabel_loader2)[0]

            # ---- discriminator step ----
            freeze(G)
            unfreeze(DG)

            # Sample on CPU then move, so the random stream matches CPU seeding.
            z = torch.randn(bs, z_dim).to(device)
            f_imgs = G(z)

            # T only supplies soft targets and is never optimised here. Without
            # no_grad, dg_loss.backward() also accumulated unused gradients in T.
            with torch.no_grad():
                y_prob = T(imgs)[-1]
            y = torch.argmax(y_prob, dim=1).view(-1)

            _, output_label = DG(imgs)
            _, output_unlabel = DG(x_unlabel)
            _, output_fake = DG(f_imgs)

            loss_lab = softXEnt(output_label, y_prob)
            lse_unlabel = log_sum_exp(output_unlabel)
            lse_fake = log_sum_exp(output_fake)
            loss_unlab = 0.5 * (torch.mean(F.softplus(lse_unlabel)) - torch.mean(lse_unlabel)
                                + torch.mean(F.softplus(lse_fake)))
            dg_loss = loss_lab + loss_unlab

            # Agreement between DG and the target classifier on real images.
            acc = torch.mean((output_label.max(1)[1] == y).float())

            dg_optimizer.zero_grad()
            dg_loss.backward()
            dg_optimizer.step()

            # ---- generator step ----
            if step % n_critic == 0:
                freeze(DG)
                unfreeze(G)
                z = torch.randn(bs, z_dim).to(device)
                f_imgs = G(z)
                mom_gen, output_fake = DG(f_imgs)
                mom_unlabel, _ = DG(x_unlabel2)

                # Feature matching on the batch-mean DG features.
                mom_gen = torch.mean(mom_gen, dim=0)
                mom_unlabel = torch.mean(mom_unlabel, dim=0)

                Hloss = entropy(output_fake)
                g_loss = torch.mean((mom_gen - mom_unlabel).abs()) + 1e-4 * Hloss

                g_optimizer.zero_grad()
                g_loss.backward()
                g_optimizer.step()

        interval = time.time() - start
        print("Epoch:%d \tTime:%.2f\tG_loss:%.2f\t train_acc:%.2f" % (epoch, interval, g_loss, acc))

        torch.save({'state_dict': G.state_dict()}, os.path.join(save_model_dir, "improved_{}_G.pt".format(dataset_name)))
        torch.save({'state_dict': DG.state_dict()}, os.path.join(save_model_dir, "improved_{}_D.pt".format(dataset_name)))

        if (epoch + 1) % 10 == 0:
            z = torch.randn(32, z_dim).to(device)
            with torch.no_grad():
                fake_image = G(z)
            save_tensor_images(fake_image, os.path.join(save_img_dir, "improved_BraTS23_img_{}.png".format(epoch)), nrow=8)

In [21]:
#main function of recovery.py
def recovery():
    """Run the KEDMI distributional inversion attack on the BraTS23 classifier.

    Creates ``latent/`` and ``imgs_L_logit/`` under
    ``/kaggle/working/attack_results/kedmi_300ids/BraTS23_VGG16_MRI/L_Logit``.
    It then loads the attack models via ``get_attack_model()`` and calls
    ``dist_inversion`` on 5 batches of 60 labels cycling over ``n_classes``.
    """
    root_path = "/kaggle/working/attack_results"
    create_directory_if_not_exists(root_path)

    run_folder = os.path.join("{}_{}".format("BraTS23", "VGG16_MRI"), "L_Logit")
    prefix = os.path.join(root_path, "kedmi_300ids", run_folder)
    save_dir = os.path.join(prefix, "latent")
    save_img_dir = os.path.join(prefix, "imgs_{}".format("L_logit"))
    for folder in (prefix, save_img_dir, save_dir):
        os.makedirs(folder, exist_ok=True)

    targetnets, E, G, D, n_classes, fea_mean, fea_logvar = get_attack_model()
    bs = 60

    attack_kwargs = dict(
        lr=0.02, iter_times=2400, momentum=0.9, lamda=100,
        clip_range=1, improved=True, num_seeds=1, used_loss='logit_loss',
        fea_mean=fea_mean, fea_logvar=fea_logvar, lam=1.0, clipz=True,
    )

    run = 0
    iden = torch.from_numpy(np.arange(bs))
    for idx in range(5):
        iden = iden % n_classes
        print("--------------------- Attack batch [%s]------------------------------" % idx)
        print('Iden:{}'.format(iden))
        save_dir_z = '{}/{}_{}'.format(save_dir, run, idx)

        # KEDMI
        print('kedmi')
        dist_inversion(G, D, targetnets, E, iden,
                       prefix=save_dir_z,
                       save_img_dir=os.path.join(save_img_dir, '{}_'.format(idx)),
                       **attack_kwargs)
        iden = iden + bs

In [22]:
# train_specific_gan()
# recovery()

In [23]:
from PIL import Image
import matplotlib.pyplot as plt
def show_saved_imgs(image_path):
    """Display the image file at ``image_path`` with matplotlib, axes hidden."""
    plt.imshow(Image.open(image_path))
    plt.axis('off')
    plt.show()

In [24]:
def test_gan():
    """Load the trained generator and display one random sample.

    Reads ``improved_BraTS23_G.pt`` from
    ``/kaggle/working/attack_results/BraTS23/VGG16_MRI``. This must match the
    file name ``train_specific_gan`` writes; the old ``.tar`` name never existed.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    G = torch.nn.DataParallel(Generator(100)).to(device)
    root_path = "/kaggle/working/attack_results"
    dataset_name = "BraTS23"
    model_name_T = "VGG16_MRI"
    path = os.path.join(root_path, dataset_name, model_name_T)
    path_G = os.path.join(path, "improved_{}_G.pt".format(dataset_name))
    # map_location lets a GPU-saved checkpoint load on a CPU-only kernel.
    ckp_G = torch.load(path_G, map_location=device)
    G.load_state_dict(ckp_G['state_dict'], strict=True)

    G.eval()
    noise = torch.randn(1, 100).to(device)
    with torch.no_grad():
        generated_image = G(noise)
    # (1, 1, 240, 240) -> (1, 240, 240) -> (240, 240) for imshow.
    generated_image = generated_image.squeeze(0).cpu().numpy()
    print(generated_image.shape)
    generated_image = np.squeeze(generated_image)

    plt.imshow(generated_image, cmap='gray')
    plt.axis('off')
    plt.show()
# test_gan()