In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"
from lightning.pytorch.loggers import TensorBoardLogger
import sys
from lightning.pytorch.callbacks import StochasticWeightAveraging
import glob
import lightning as L
sys.path.insert(0, "/home/tordjx/DATA_DIR/config/projects/TRANSFERABLECLEANDATAPOISONING/lib/python/")
from customlib.dataloaders import CustomDataset
data_dir = '/home/tordjx/DATA_DIR/managed_folders/TRANSFERABLECLEANDATAPOISONING/LMc8Smw6/'
import timm
import numpy as np
import pandas as pd
import skimage.io
import matplotlib.pyplot as plt
import torchmetrics
import torch
torch.manual_seed(0)
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Resize
# Training split is built with poison_dir set to the folder that
# Generator.generate_poisons writes to by default; the test split has no poison_dir.
train_dataset = CustomDataset(
    data_dir,
    train=True,
    poison_dir="/home/tordjx/DATA_DIR/managed_folders/TRANSFERABLECLEANDATAPOISONING/BqfvFGr8",
)
test_dataset = CustomDataset(data_dir, train=False)

# Create data loaders (batch_size is also read as a global by later cells)
batch_size = 32
loader_kwargs = {"batch_size": batch_size, "num_workers": 19}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)


In [2]:

class Decoder(L.LightningModule):
    """Small trainable heads placed on top of image encoders.

    One hidden projection ``nn.Linear(F, intermediate_size)`` is created per distinct
    encoder output width ``F``; encoders that share a width share the projection.
    Two heads sit on the ReLU-activated hidden layer:

    * ``classif_head``   -> ``nclasses`` class logits (used by ``forward``),
    * ``fake_detectors`` -> 2 logits (used by ``detect_fake``).

    Both heads return raw (un-activated) logits, which is what
    ``nn.CrossEntropyLoss`` expects.

    Args:
        encoders: iterable of modules; each is called once on a dummy
            ``(2, 3, 128, 128)`` batch and the last dimension of its output is taken
            as its feature width.
        intermediate_size: width of the hidden layer.
        nclasses: number of class logits.
    """

    def __init__(self, encoders, intermediate_size=128, nclasses=43):
        super().__init__()
        # set() deduplicates widths so encoders of equal width share one projection.
        feature_sizes = list(set(self._probe_feature_size(encoder) for encoder in encoders))
        # Feature width -> index of its projection in ``intermediate_layers``.
        self.feature_sizes_dict = {size: index for index, size in enumerate(feature_sizes)}
        self.intermediate_layers = nn.ModuleList(
            [nn.Linear(f_size, intermediate_size) for f_size in feature_sizes]
        )
        self.classif_head = nn.Linear(intermediate_size, nclasses)
        self.fake_detectors = nn.Linear(intermediate_size, 2)

    @staticmethod
    def _probe_feature_size(encoder):
        """Return the last-dim size of ``encoder``'s output on a random 2x3x128x128 batch.

        The probe runs in eval mode under ``no_grad`` so it neither builds an autograd
        graph nor updates BatchNorm running statistics of (pretrained) encoders with
        random noise. The encoder's previous train/eval mode is restored afterwards.
        """
        was_training = encoder.training
        encoder.eval()
        try:
            with torch.no_grad():
                return encoder(torch.rand(2, 3, 128, 128)).shape[-1]
        finally:
            encoder.train(was_training)

    def _project(self, x):
        """Apply the projection matching ``x.shape[-1]`` followed by ReLU."""
        layer = self.intermediate_layers[self.feature_sizes_dict[x.shape[-1]]]
        return nn.functional.relu(layer(x))

    def forward(self, x):
        """Return raw class logits for encoder features ``x``."""
        return self.classif_head(self._project(x))

    def detect_fake(self, x):
        """Return raw 2-way logits for encoder features ``x``.

        The GAN in this notebook trains index 0 for clean and index 1 for poisoned images.
        """
        return self.fake_detectors(self._project(x))
class Discriminator(L.LightningModule):
    """Ensemble of pretrained timm encoders sharing one ``Decoder``.

    ``forward`` returns ``(class_logits, fake_logits)``; each is the decoder output for
    every encoder, stacked along dim 0.

    Args:
        encoder_names: timm model names, each built with ``num_classes=0`` and
            ``pretrained=True``.
        decoder_size: hidden width of the decoder (passed as ``intermediate_size``).
        nclasses: number of class logits produced by the decoder.
        encoder_freeze: if True, encoder weights get ``requires_grad=False`` and the
            encoders are kept in eval mode (see ``train``).
    """

    def __init__(self, encoder_names, decoder_size=128, nclasses=43, encoder_freeze=True):
        super().__init__()
        self.encoder_freeze = encoder_freeze
        self.encoders = nn.ModuleList([
            timm.create_model(encoder_name, num_classes=0, pretrained=True)
            for encoder_name in encoder_names
        ])
        if encoder_freeze:
            for param in self.encoders.parameters():
                param.requires_grad = False
            # requires_grad=False alone is not enough: in train mode BatchNorm layers
            # still update their running statistics (and any dropout stays active).
            self.encoders.eval()
        # decoder_size / nclasses used to be accepted but silently ignored.
        self.decoder = Decoder(self.encoders, intermediate_size=decoder_size, nclasses=nclasses)

    def train(self, mode=True):
        """Set train/eval mode recursively, but keep frozen encoders in eval mode."""
        super().train(mode)
        if self.encoder_freeze:
            self.encoders.eval()
        return self

    def forward(self, x):
        """Return ``(class_logits, fake_logits)`` stacked over encoders (dim 0)."""
        embeddings = [encoder(x) for encoder in self.encoders]
        class_logits = torch.stack([self.decoder(emb) for emb in embeddings])
        fake_logits = torch.stack([self.decoder.detect_fake(emb) for emb in embeddings])
        return class_logits, fake_logits


In [3]:
from monai.networks.nets import UNet,BasicUNetPlusPlus
class Generator(L.LightningModule):
    """Perturbation generator: a 2-D MONAI ``BasicUNetPlusPlus`` with 3 input and 3 output channels.

    The tensor returned by ``forward`` is added to the input image to form the
    poisoned image.
    """

    def __init__(self):
        super().__init__()
        # Previously tried architecture, kept for reference:
        # UNet(spatial_dims=2, in_channels=3, out_channels=3,
        #      channels=(4, 8, 16, 32), strides=(2, 2, 2), num_res_units=3)
        self.model = BasicUNetPlusPlus(spatial_dims=2, in_channels=3, out_channels=3)

    def forward(self, x):
        """Return the perturbation for image batch ``x`` (first element of the UNet++ output list)."""
        return self.model(x)[0]

    def generate_poisons(self, path="/home/tordjx/DATA_DIR/managed_folders/TRANSFERABLECLEANDATAPOISONING/BqfvFGr8"):
        """Regenerate the poisoned training set in ``path``.

        ``path`` is created if missing and every (non-hidden) entry directly inside it
        is removed. Each training image ``x`` from
        ``CustomDataset(data_dir, train=True, return_names=True)`` is then saved with
        ``np.save`` as ``x + self(x)``, under its original file name with the extension
        replaced by ``.npy``.

        The generator runs on the device of its own parameters, in eval mode and without
        gradients; its previous train/eval mode is restored afterwards.

        Relies on the notebook globals ``data_dir`` and ``batch_size``.
        """
        os.makedirs(path, exist_ok=True)
        for stale_file in glob.glob(os.path.join(path, "*")):
            os.remove(stale_file)

        dataset = CustomDataset(data_dir, train=True, return_names=True)
        # Every sample is written exactly once, so shuffling is unnecessary.
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=19)
        # Use wherever the model currently lives instead of assuming CUDA.
        device = next(self.parameters()).device

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for x, _, names in loader:
                    poisoned_batch = (x + self(x.to(device)).cpu()).numpy()
                    for poisoned, name in zip(poisoned_batch, names):
                        # Swap only the extension; str.replace("ppm", "npy") would also
                        # rewrite "ppm" appearing elsewhere in the name.
                        out_name = os.path.splitext(name)[0] + ".npy"
                        np.save(os.path.join(path, out_name), poisoned)
        finally:
            self.train(was_training)


2024-01-19 12:59:17.731726: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-01-19 12:59:17.732996: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-01-19 12:59:17.750361: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-19 12:59:17.750377: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-19 12:59:17.750390: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to regi

In [4]:
class GAN(L.LightningModule):
    """Adversarial training of a perturbation generator against a multi-encoder discriminator.

    Uses manual optimization: each ``training_step`` first updates the discriminator,
    then the generator.

    * Discriminator objective: classify clean images ``x`` and poisoned images
      ``x + generator(x)`` into class ``y``, and label them 0 (clean) / 1 (poisoned).
    * Generator objective: maximize the discriminator's class and clean/poisoned
      cross-entropy on poisoned images (target 1), plus an ``alpha``-weighted MSE
      penalty that keeps the perturbation small.

    Args:
        generator: module whose output is added to its input image.
        discriminator: module returning ``(class_logits, fake_logits)``, each stacked
            over encoders along dim 0.
        train_dataloader: stored but not read by this class; the loader actually used
            for training is the one given to ``Trainer.fit``.
        val_dataloader: one batch is drawn from it each epoch to log samples to
            TensorBoard.
        lr: Adam learning rate for both optimizers.
        alpha: weight of the perturbation L2 penalty in the generator loss.
    """

    def __init__(self, generator, discriminator, train_dataloader, val_dataloader, lr=5e-4, alpha=1000):
        super().__init__()
        # Stored under private names: assigning to ``self.train_dataloader`` /
        # ``self.val_dataloader`` would shadow the LightningModule hooks of the same
        # names with DataLoader objects.
        self._train_loader = train_dataloader
        self._val_loader = val_dataloader
        self.generator = generator
        self.discriminator = discriminator
        self.criterion = nn.CrossEntropyLoss()
        self.lr = lr
        self.alpha = alpha
        self.automatic_optimization = False

    def _mean_ce(self, logits_per_encoder, target):
        """Mean over dim 0 (encoders) of ``criterion(logits, target)``."""
        return torch.stack([self.criterion(logits, target) for logits in logits_per_encoder]).mean()

    def training_step(self, batch, batch_idx):
        """One discriminator update followed by one generator update on ``batch = (x, y)``."""
        x, y = batch
        optimizer_g, optimizer_d = self.optimizers()
        # Created on y's device (no hard-coded .cuda()).
        real_label = torch.zeros_like(y)
        fake_label = torch.ones_like(y)

        ## DISCRIMINATOR STEP (generator output treated as a constant)
        with torch.no_grad():
            perturbation = self.generator(x)
        poisoned_image = perturbation + x
        predictions, predicted_fakeness = self.discriminator(poisoned_image)
        discriminator_loss_poisoned = (
            self._mean_ce(predictions, y) + self._mean_ce(predicted_fakeness, fake_label)
        )
        predictions, predicted_fakeness = self.discriminator(x)
        discriminator_loss_real = (
            self._mean_ce(predictions, y) + self._mean_ce(predicted_fakeness, real_label)
        )
        discriminator_loss = discriminator_loss_poisoned + discriminator_loss_real
        optimizer_d.zero_grad()
        self.manual_backward(discriminator_loss)
        optimizer_d.step()
        self.log("discriminator_loss_real", discriminator_loss_real)
        self.log("discriminator_loss_poisoned", discriminator_loss_poisoned)
        self.log("discriminator_loss", discriminator_loss)

        ## GENERATOR STEP
        perturbation = self.generator(x)
        poisoned_image = perturbation + x
        # Build the discriminator graph with decoder weights treated as constants, so the
        # generator backward does not accumulate gradients into them. try/finally makes
        # sure they are re-enabled even if the forward pass raises.
        decoder_params = list(self.discriminator.decoder.parameters())
        for param in decoder_params:
            param.requires_grad = False
        try:
            predictions, predicted_fakeness = self.discriminator(poisoned_image)
        finally:
            for param in decoder_params:
                param.requires_grad = True
        classification_loss_gen = self._mean_ce(predictions, y)
        fake_detector_loss_gen = self._mean_ce(predicted_fakeness, fake_label)
        l2_pen = nn.MSELoss()(perturbation, torch.zeros_like(perturbation))
        generator_loss = -(classification_loss_gen + fake_detector_loss_gen) + self.alpha * l2_pen
        optimizer_g.zero_grad()
        # Single backward through this graph: no need to retain it.
        self.manual_backward(generator_loss)
        optimizer_g.step()
        self.log("alpha*L2_pen", self.alpha * l2_pen)
        self.log("generator_loss", generator_loss)
        self.log("fake_detector_loss_gen", fake_detector_loss_gen)

    def configure_optimizers(self):
        """Return one Adam optimizer for the generator and one for the discriminator (in that order)."""
        optimizer_G = torch.optim.Adam(self.generator.parameters(), lr=self.lr, weight_decay=1e-4)
        optimizer_D = torch.optim.Adam(self.discriminator.parameters(), lr=self.lr, weight_decay=1e-4)
        return [optimizer_G, optimizer_D], []

    def on_train_epoch_start(self):
        """Log the first sample's perturbation (full and per channel), clean and poisoned image."""
        if self.logger is None:
            return
        image, _ = next(iter(self._val_loader))
        with torch.no_grad():
            perturb = self.generator(image.to(self.device)).cpu()
        idx = 0
        step = self.current_epoch
        writer = self.logger.experiment
        sample_perturb = perturb[idx]
        sample_image = image[idx]
        writer.add_image("Full perturbation", sample_perturb.moveaxis(0, -1), step, dataformats="HWC")
        for channel, colour in enumerate(("Red", "Green", "Blue")):
            writer.add_image(f"{colour} perturbation", sample_perturb[channel], step, dataformats="HW")
        writer.add_image("Clean image", sample_image.moveaxis(0, -1), step, dataformats="HWC")
        writer.add_image(
            "Poisoned image",
            (sample_perturb + sample_image).moveaxis(0, -1),
            step,
            dataformats="HWC",
        )

In [None]:
os.environ["TENSORBOARD_BINARY"]="/home/tordjx/DATA_DIR/code-envs/python/rapids/bin/tensorboard"
%reload_ext tensorboard
%tensorboard --logdir=lightning_logs/
log_name = "5 backbones"
logger = TensorBoardLogger("lightning_logs/", name=log_name)
trainer = L.Trainer(max_epochs=15, logger = logger)
encoder_names = ['resnet34' , "resnest26d","efficientnet_b0","regnetx_006","densenet121"]
generator = Generator()
discriminator = Discriminator(encoder_names)
gan = GAN(generator, discriminator, train_loader, test_loader)
trainer.fit(gan, train_loader, test_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


BasicUNetPlusPlus features: (32, 32, 64, 128, 256, 32).


/home/tordjx/DATA_DIR/code-envs/python/rapids/lib/python3.10/site-packages/lightning/pytorch/trainer/configuration_validator.py:72: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params
---------------------------------------------------
0 | generator     | Generator        | 2.4 M 
1 | discriminator | Discriminator    | 53.6 M
2 | criterion     | CrossEntropyLoss | 0     
---------------------------------------------------
3.1 M     Trainable params
52.9 M    Non-trainable params
56.0 M    Total params
224.165   Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

In [None]:
image, y = next(iter(train_loader))
# .cuda() moves gan.generator onto the GPU in place; it stays there for later cells.
with torch.no_grad():
    perturb = gan.generator.cuda()(image.cuda()).cpu()

# For the first three samples: the perturbation (full and per channel), the clean
# image, and the clean image with the perturbation added.
for idx in range(3):
    perturb_hwc = perturb[idx].moveaxis(0, -1)
    image_hwc = image[idx].moveaxis(0, -1)
    panels = [
        (perturb_hwc, "Full perturbation"),
        (perturb[idx][0], "Red perturbation"),
        (perturb[idx][1], "Green perturbation"),
        (perturb[idx][2], "Blue perturbation"),
        (image_hwc, "Clean image"),
        (perturb_hwc + image_hwc, "Image + perturbation"),
    ]
    fig, ax = plt.subplots(1, len(panels), figsize=(30, 5))
    for axis, (picture, title) in zip(ax, panels):
        axis.imshow(picture)
        axis.set_title(title)
    plt.show()


In [None]:
# Replace the contents of the poison folder with samples from the trained generator.
trained_generator = gan.generator
trained_generator.generate_poisons()