In [1]:
import sys
sys.path.append('unet_backbones')
from unet_backbones.backbones_unet.model.unet import Unet
from unet_backbones.backbones_unet.utils.dataset import SemanticSegmentationDataset
from unet_backbones.backbones_unet.model.losses import DiceLoss
from unet_backbones.backbones_unet.utils.trainer import Trainer
from unet_backbones.backbones_unet.utils.reproducibility import set_seed
from unet_backbones.datasets import VOCSegmentation
from unet_backbones.utils import ext_transforms as et

import os
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.transforms import Normalize
from torch.utils.data import DataLoader
from unet_backbones.metrics import StreamSegMetrics

import json
import numpy as np

In [2]:
# --- Experiment configuration -----------------------------------------------
seed = 42

# Pascal VOC 2012 data
voc_data_root = "datasets/VOC2012"
voc_crop_size = 256
voc_num_classes = 21

# Model / run settings
encoder = "convnext_tiny"
small_decoder = True
epochs = 1
lr = 1e-4

save = "experiments"
dataset = "pascalvoc2012"

# Output layout:
#   <save>/<dataset>/<encoder>/small_deocder_<bool>/epochs_<n>/lr_<lr>/seed_<seed>
# NOTE: "small_deocder_" (sic) is kept verbatim so the names keep matching folders
# already written by earlier runs.
save_path = os.path.join(
    save,
    dataset,
    encoder,
    f"small_deocder_{small_decoder}",
    f"epochs_{epochs}",
    f"lr_{lr}",
    f"seed_{seed}",
)
model_path = os.path.join(save_path, "model")
json_path = os.path.join(save_path, "losses.json")
attack_save_path = os.path.join(save_path, "attack")

# Create the run folder and its sub-folders (no-op if they already exist).
os.makedirs(save_path, exist_ok=True)
os.makedirs(model_path, exist_ok=True)
os.makedirs(attack_save_path, exist_ok=True)

In [3]:
def get_logger(save_folder):
    """Return the shared "main-logger", logging INFO and above to the console and to ``<save_folder>/log.log``.

    Safe to call repeatedly (e.g. when the notebook cell is re-run). The console
    handler is attached only once, so each message is printed exactly once.

    Args:
        save_folder: directory (str or path-like) in which ``log.log`` is appended to.

    Returns:
        logging.Logger: the logger named "main-logger", with level INFO.

    Note:
        The file output is configured on the *root* logger via ``logging.basicConfig``.
        That call is a no-op once the root logger already has handlers. A later call
        with a different ``save_folder`` therefore keeps writing to the first log file.
    """
    fmt = "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s"
    log_path = os.path.join(str(save_folder), "log.log")
    # Root-level file handler. Records from "main-logger" propagate to it. Give it the
    # same format as the console so the file keeps timestamps and call sites.
    logging.basicConfig(filename=log_path, filemode="a", format=fmt)

    logger = logging.getLogger("main-logger")
    logger.setLevel(logging.INFO)

    # Without this guard every call (every re-run of the cell) would stack another
    # console handler, and each message would be printed once per call.
    # `type(...) is` (not isinstance) so FileHandler, a StreamHandler subclass, is not counted.
    has_console = any(type(h) is logging.StreamHandler for h in logger.handlers)
    if not has_console:
        console = logging.StreamHandler()
        console.setFormatter(logging.Formatter(fmt))
        logger.addHandler(console)
    return logger

In [4]:
# Seed the RNGs first (the helper prints "> SEEDING DONE"), then attach the run logger
# that writes next to this run's outputs.
set_seed(seed)
logger = get_logger(save_folder=save_path)

> SEEDING DONE


In [5]:
class Args:
    """Plain attribute bag with the options the Trainer reads as ``args``.

    Attributes (set in this order):
        epsilon: perturbation budget of the adversarial attack.
        mode: run mode, one of 'adv_attack', 'adv_train', 'train', 'test'.
        alpha: per-iteration step size of the attack.
        iterations: number of attack iterations.
        attack: attack name, one of 'fgsm', 'cospgd', 'segpgd', 'pgd'.
        num_classes: number of segmentation classes.
        norm: threat-model norm, one of 'inf', 'two', 'one'.
        targeted: whether the attack is targeted (this notebook passes a bool).
        save_attack: whether to save the perturbed images.
        save_attack_path: directory the perturbed images are written to.
    """

    def __init__(self, epsilon, mode, alpha, iterations, attack, num_classes, norm, targeted, save_attack, save_attack_path):
        # The ordered table keeps the same attribute insertion order as listed above.
        settings = (
            ("epsilon", epsilon),
            ("mode", mode),
            ("alpha", alpha),
            ("iterations", iterations),
            ("attack", attack),
            ("num_classes", num_classes),
            ("norm", norm),
            ("targeted", targeted),
            ("save_attack", save_attack),
            ("save_attack_path", save_attack_path),
        )
        for attr_name, attr_value in settings:
            setattr(self, attr_name, attr_value)

In [6]:
# Evaluation preprocessing: resize, center-crop to `voc_crop_size`, convert to tensor.
# No Normalize step here; normalization is applied inside the model wrapper.
val_transform = et.ExtCompose(
    [
        et.ExtResize(voc_crop_size),
        et.ExtCenterCrop(voc_crop_size),
        et.ExtToTensor(),
    ]
)

# NOTE(review): confirm which images this repo's VOCSegmentation maps image_set="test" to
# (only the trainval archive is downloaded, per the cell output).
test_dst = VOCSegmentation(
    root=voc_data_root,
    year="2012",
    image_set="test",
    download=True,
    transform=val_transform,
)

Using downloaded and verified file: datasets/VOC2012/VOCtrainval_11-May-2012.tar


In [7]:
# Only an evaluation loader is built: train_loader stays None for this attack run.
train_loader = None
test_loader = DataLoader(
    dataset=test_dst,
    batch_size=1,
    shuffle=False,  # fixed order so attacked images line up with dataset indices
    num_workers=1,
)

In [8]:
# U-Net with the configured encoder backbone, RGB input, `voc_num_classes` classes.
model = Unet(
    in_channels=3,
    num_classes=voc_num_classes,
    backbone=encoder,
    small_decoder=small_decoder,
)

In [9]:
# Per-channel mean/std (the standard ImageNet statistics). Normalization is folded into
# the model, presumably so the attack perturbs and clips raw-pixel inputs. The loaded
# checkpoint's keys match this wrapped layout.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Guard keeps this cell idempotent. Re-running it must not stack a second Normalize
# layer, which would double-normalize inputs and break load_state_dict key matching.
already_wrapped = (
    isinstance(model, nn.Sequential)
    and len(model) > 0
    and isinstance(model[0], Normalize)
)
if not already_wrapped:
    model = nn.Sequential(Normalize(mean=mean, std=std), model)

In [10]:
# Rebuild the checkpoint path from the config cell instead of hard-coding it. Changing
# `encoder`, `lr`, `seed`, `small_decoder`, ... then loads the matching model.
# This run uses `epochs = 1` for its own output folder, but the evaluated model was
# trained for 150 epochs, hence the separate value. "small_deocder_" (sic) must match
# the folder name written at training time.
checkpoint_epochs = 150
path = os.path.join(
    save,
    dataset,
    encoder,
    "small_deocder_" + str(small_decoder),
    "epochs_" + str(checkpoint_epochs),
    "lr_" + str(lr),
    "seed_" + str(seed),
    "model",
    "final_model.pt",
)
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only machine.
# load_state_dict then copies the tensors into the model's existing parameters.
# weights_only=False unpickles arbitrary objects, so only load checkpoints you trust.
checkpoint = torch.load(path, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model_state_dict"])

<All keys matched successfully>

In [14]:
# Untargeted attack; this flag drives both the extra metric trackers and `args`.
targeted = False

# The Trainer takes an optimizer, built over the trainable parameters only. Whether
# adv_attack mode ever steps it is decided inside trainer.py.
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.AdamW(params, lr=lr)

# Unreduced (per-element) losses; pixels labelled 255 are ignored by cross-entropy.
criterion = {
    "cross_entropy": nn.CrossEntropyLoss(ignore_index=255, reduction="none"),
    "dice_loss": DiceLoss(reduction=None),
}

metrics = StreamSegMetrics(voc_num_classes)
# The two additional trackers exist only for targeted attacks.
if targeted:
    actual_metrics = StreamSegMetrics(voc_num_classes)
    initial_metrics = StreamSegMetrics(voc_num_classes)
else:
    actual_metrics = None
    initial_metrics = None

# NOTE(review): alpha (step size, 8/255 ~= 0.031) is larger than epsilon (0.01). Under an
# L-inf budget the step is normally <= epsilon; common settings (e.g. CosPGD) pair
# epsilon ~= 8/255 with alpha = 0.01. Confirm these two are not swapped before reporting.
args = Args(
    epsilon=0.01,
    mode="adv_attack",
    alpha=8 / 255,
    iterations=500,
    attack="cospgd",
    num_classes=voc_num_classes,
    norm="inf",
    targeted=targeted,
    save_attack=True,
    save_attack_path=attack_save_path,
)

In [15]:
# Single pass (epochs=1) with cross-entropy as the loss the attack and evaluation use.
trainer = Trainer(
    model,
    args=args,
    criterion=criterion["cross_entropy"],
    optimizer=optimizer,
    epochs=1,
    metrics=metrics,
    actual_metrics=actual_metrics,
    initial_metrics=initial_metrics,
    logger=logger,
    model_save_path=os.path.join(model_path, "best_model.pt"),
)

In [16]:
# Run the Trainer. train_loader is None here. With args.mode == "adv_attack" the Trainer
# presumably attacks/evaluates on test_loader only; confirm the mode dispatch in trainer.py.
trainer.fit(train_loader, test_loader)

Traning Model on 1 epochs:   0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ? validating-batch/s]

[2024-11-08 15:43:20,610 INFO trainer.py line 308 10790] iteration:0---0.6407607197761536
[2024-11-08 15:43:20,655 INFO trainer.py line 308 10790] iteration:1---1.4699273109436035
[2024-11-08 15:43:20,705 INFO trainer.py line 308 10790] iteration:2---2.4529876708984375
[2024-11-08 15:43:20,756 INFO trainer.py line 308 10790] iteration:3---1.6162222623825073
[2024-11-08 15:43:20,808 INFO trainer.py line 308 10790] iteration:4---2.271636962890625
[2024-11-08 15:43:20,856 INFO trainer.py line 308 10790] iteration:5---1.7584843635559082
[2024-11-08 15:43:20,907 INFO trainer.py line 308 10790] iteration:6---1.6042613983154297
[2024-11-08 15:43:20,958 INFO trainer.py line 308 10790] iteration:7---1.308849573135376
[2024-11-08 15:43:21,010 INFO trainer.py line 308 10790] iteration:8---2.193924903869629
[2024-11-08 15:43:21,063 INFO trainer.py line 308 10790] iteration:9---2.3155598640441895
[2024-11-08 15:43:21,116 INFO trainer.py line 308 10790] iteration:10---1.9236608743667603
[2024-11-08 