In [15]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision import datasets
from torch.optim import Adam
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader,  Subset
from tqdm import tqdm
import os
import numpy as np
from PIL import Image, ImageOps
import staintools
from statistics import median
from tqdm import tqdm
import random
import argparse
import json
import copy 
from types import SimpleNamespace
import models.tent as tent
from models.delta import DELTA
from models.test_utils import LAME
import yaml
import torch.optim as optim
import models.sar as sar
from models.sam import SAM
import models.tent as TENT1
import math



In [16]:

def custom_transform(image_path):
    image = staintools.read_image(image_path)
    image = staintools.LuminosityStandardizer.standardize(image)
    normalized_image = stain_norm.transform(image)
    im_pil = Image.fromarray(normalized_image.astype('uint8'), 'RGB')  # Convert to PIL image for transforms
    transform = transforms.Compose([
            transforms.Resize((m_p_s, m_p_s)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    return transform(im_pil)

# Define a custom dataset class
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        self.dataset = datasets.ImageFolder(root=self.root)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img_path, label = self.dataset.imgs[idx]
        if self.transform:
            img = self.transform(img_path)
        return img, label



def setup_tent(model,steps,episodic,METHOD,noaffine=False):
    """Set up tent adaptation.

    Configure the model for training + feature modulation by batch statistics,
    collect the parameters for feature modulation by gradient optimization,
    set up the optimizer, and then tent the model.
    """
    model = TENT1.configure_model(model,noaffine=noaffine)
    params, param_names = TENT1.collect_params(model)
    if METHOD=="SGD":
        optimizer = optim.SGD(params,
                        lr=0.00025,
                        momentum=0.9
                        )
    elif METHOD=="Adam":
        optimizer = optim.Adam(params,
                    lr=1e-3,
                    betas=(0.9, 0.999),
                    weight_decay=0)

    tent_model = TENT1.Tent(model, optimizer,
                           steps=steps,#cfg.OPTIM.STEPS
                           episodic=episodic,
                          noaffine=noaffine)#cfg.MODEL.EPISODIC
    return tent_model

        
def eval_TTA(methods,adapt_model,chosen_loader):
#     with torch.no_grad():

        correct = {}
        total = {}
        batch_acc = {}
        class_num = {}
        class_correct = {}
        class_avg_acc = {}
        cumulative_acc = {}
        for method in methods:
            correct[method] = 0
            total[method] = 0   
            batch_acc[method] = 0
            class_num[method] = np.array([0]*10)
            class_correct[method] = np.array([0]*10)
            class_avg_acc[method] = 0
            cumulative_acc[method] = 0
        results={}
        for method in methods:
            batch_idx=0
            iters = iter(chosen_loader)
            print(method)
            while(batch_idx<10):
                print(batch_idx)
                try:
                    inputs, targets = next(iters)
                except:
                    print("empty mask")
                    continue
                inputs, targets = inputs.cuda(), targets.cuda()
                batch_result={}
                model=adapt_model[method]
                outputs = model(inputs)
                _, predicted = outputs.max(1)
                total[method] = targets.size(0)
                correct[method] = predicted.eq(targets).sum().item()
                batch_acc[method] = 100.*correct[method]/total[method]
                for i, t in enumerate(targets):
                    class_num[method][t.item()] += 1
                    class_correct[method][t.item()] += (predicted[i]==t)
                acc = (class_correct[method][class_num[method]!=0] / class_num[method][class_num[method]!=0])
                class_avg_acc[method] = acc.mean() * 100.
                cumulative_acc[method] = class_correct[method].sum() / class_num[method].sum() * 100
                batch_result[f'{method}_cumulative_accuracy_{batch_idx}'] = cumulative_acc[method]
                batch_result[f'{method}_batch_accuracy_{batch_idx}'] = batch_acc[method]
                batch_result[f'{method}_class_accuracy_{batch_idx}'] = class_avg_acc[method]

                results.update(batch_result)
                tentnet.reset()
                batch_idx+=1
        return results


In [17]:
class Args:
    def __init__(self):
        self.cor_path="02_training_native"
        self.model_name="TvN_350_SN_D256_Initial_Ep7_fullmodel.pth"
        self.exp_type="imbalanced_experiments"
        self.batch_size=64

In [18]:
args=Args()

In [19]:

# parser = argparse.ArgumentParser(description='pathology TTA')
# parser.add_argument('--artifact', type=str)
# parser.add_argument('--cor_path', type=str)
# parser.add_argument('--exp_type', default = 'imbalanced_experiments', type=str)
# parser.add_argument('--rho', required=False, type=float, help='long tail factor')
# parser.add_argument('--pi', required=False, type=float, help='dirichlet factor')
# parser.add_argument('--model_name', default = 'TvN_350_SN_D256_Initial_Ep7_fullmodel.pth', type=str)
# args = parser.parse_args()

parameters={}
corrupt_data_path=args.cor_path
model_name = args.model_name

# print(artifact)
# print(os.getcwd())
st = staintools.read_image("./Artifact/15_stain_scheme/schemes_ready/standard_he_stain_small.jpg")
standardizer = staintools.LuminosityStandardizer.standardize(st)
stain_norm = staintools.StainNormalizer(method='macenko')
stain_norm.fit(st)
number_of_classes=3
model_dir = 'Models'
m_p_s = 350



In [34]:
def sample_data_dist(args,labels):
    len_dataset =len(labels)
    exp_setup=args.exp_setup
    if exp_setup == "IS+CB":
        idx = [i for i in range(len_dataset)]
        random.shuffle(idx)
            
    elif exp_setup == "IS+CI":
        idx = dirichlet_split_noniid(np.array(labels), args.pi, 10)
        idx = np.concatenate(idx)
    elif exp_setup == "DS+CB":
        
        prob_per_class = []
        for cls_idx in range(number_of_classes):
            prob_per_class.append( args.rho ** (cls_idx / (number_of_classes - 1.0)) )
        prob_per_class = np.array(prob_per_class) / sum(prob_per_class)
        img_per_class = prob_per_class * len(labels)
        idx = []
        y_test_np = np.array(labels)
        for c, num in enumerate(img_per_class):
            all_of_c = np.where(y_test_np==c)[0]
            idx.append(np.random.choice(all_of_c, int(num)+1))
        idx = np.concatenate(idx)
        random.shuffle(idx)
        
    elif exp_setup == "DS+CI":
        
        prob_per_class = []
        for cls_idx in range(number_of_classes):
            prob_per_class.append( args.rho ** (cls_idx / (number_of_classes - 1.0)) )
        prob_per_class = np.array(prob_per_class) / sum(prob_per_class)
        img_per_class = prob_per_class * len(labels)
        idx = []
        y_test_np = np.array(labels)
        for c, num in enumerate(img_per_class):
            all_of_c = np.where(y_test_np==c)[0]
            idx.append(np.random.choice(all_of_c, int(num)+1))
        idx = np.concatenate(idx)
        idx2 = dirichlet_split_noniid(np.array([y_test_np[i] for i in idx]), args.pi, 10)
        idx = np.concatenate([idx[i] for i in idx2])
        
    return idx
        

In [35]:
experiments=[
 {"exp_setup":"IS+CB"},
 {"exp_setup":"DS+CB", "rho":1},
 {"exp_setup":"DS+CB", "rho":0.5},
 {"exp_setup":"DS+CB", "rho":0.1},
 {"exp_setup":"IS+CI", "pi":0.1},
 {"exp_setup":"IS+CI", "pi":0.05},
 {"exp_setup":"DS+CI", "rho":0.5, "pi":0.1},
 {"exp_setup":"DS+CI", "rho":0.5, "pi":0.05},
]

In [36]:
class Exp:
    def __init__(self, exp):
        self.exp_setup=exp["exp_setup"]
        self.rho=exp["rho"] if "rho" in exp.keys() else None
        self.pi=exp["pi"] if "pi" in exp.keys() else None
        self.setup_name=f"{self.exp_setup}"
        if self.rho:
            self.setup_name+=f"_rho{self.rho}"
        if self.pi:
            self.setup_name+=f"_pi{self.pi}"

In [37]:
path= f"TTA_on_corrupted/{args.exp_type}"

In [None]:
for exp in experiments:
    experiment_args=Exp(exp)
    setup=""
    parameters=exp
    json_data=[]
    for seed in  [0,19,22, 42, 81]: #2020 , 42, 81
        parameters["seed"]=seed
        parameters["model"]=args.model_name
        
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        with open('configs/cifar_tentdelta_adam.yaml', 'rb') as f:
            args_tent = yaml.safe_load(f.read())
        config_obj = SimpleNamespace(**args_tent)

        path_model = os.path.join(model_dir, model_name)
        net = torch.load(path_model) 
        net = net.cuda()
        net.eval()
        lame_model=LAME(copy.deepcopy(net),3,5,1)
        delta_model= DELTA(config_obj,copy.deepcopy(net))
        sarnet = sar.configure_model(copy.deepcopy(net))
        params, param_names = sar.collect_params(sarnet)
        base_optimizer = torch.optim.SGD
        optimizer = SAM(params, base_optimizer, lr=0.00025, momentum=0.9) #lr suitable for batch size >32
        sar_model = sar.SAR(sarnet, optimizer, margin_e0=math.log(3)*0.40) # since we have 3 classes

        tentnet = setup_tent(copy.deepcopy(net),10,False,METHOD="Adam")
        TENT1.check_model(tentnet)


        methods=["TENT", "DELTA","NOT_ADAPTED","SAR","LAME"] #"TTN" "DELTA","NOT_ADAPTED", "LAME", "DELTA","NOT_ADAPTED", "LAME","SAR"
        adapt_model={"DELTA":delta_model,"TENT":tentnet,"NOT_ADAPTED":net,"LAME":lame_model,"SAR":sar_model} #"TTN":norm_net_TTN, 
        custom_dataset = CustomDataset(root=f"Corrupted_data/{corrupt_data_path}/00_original", transform=custom_transform)
        labels = custom_dataset.dataset.targets
        idx = sample_data_dist(experiment_args,labels)
        subset = Subset(custom_dataset, idx)
        dataloader = DataLoader(subset, batch_size=args.batch_size, shuffle=False)

        correct = {}
        total = {}
        batch_acc = {}
        class_num = {}
        class_correct = {}
        class_avg_acc = {}
        cumulative_acc = {}
        for method in methods:
            correct[method] = 0
            total[method] = 0   
            batch_acc[method] = 0
            class_num[method] = np.array([0]*3)
            class_correct[method] = np.array([0]*3)
            class_avg_acc[method] = 0
            cumulative_acc[method] = 0
        results={}
        for method in methods:
            batch_idx=0
            iters = iter(dataloader)
            while(batch_idx<10):
                print(batch_idx)
                try:
                    inputs, targets = next(iters)
                except:
                    print("empty mask")
                    continue
                inputs, targets = inputs.cuda(), targets.cuda()
                batch_result={}
                outputs=adapt_model[method](inputs)
                _, predicted = outputs.max(1)
                total[method] = targets.size(0)
                correct[method] = predicted.eq(targets).sum().item()
                batch_acc[method] = 100.*correct[method]/total[method]
                for i, t in enumerate(targets):
                    class_num[method][t.item()] += 1
                    class_correct[method][t.item()] += (predicted[i]==t)
                acc = (class_correct[method][class_num[method]!=0] / class_num[method][class_num[method]!=0])
                class_avg_acc[method] = acc.mean() * 100.
                cumulative_acc[method] = class_correct[method].sum() / class_num[method].sum() * 100
                batch_result[f'{method}_cumulative_accuracy_{batch_idx}'] = cumulative_acc[method]
                batch_result[f'{method}_batch_accuracy_{batch_idx}'] = batch_acc[method]
                batch_result[f'{method}_class_accuracy_{batch_idx}'] = class_avg_acc[method]

                results.update(batch_result)
                adapt_model["TENT"].reset()
                batch_idx+=1

        json_entry = {"parameters": parameters,
                "results": results
                }
        json_data.append(json_entry)


    if not os.path.exists(path):
        os.makedirs(path)
    with open(f"{path}/results_TTA_{setup}_{model_name[:-4]}.json", 'w') as json_file:
        json.dump(json_data, json_file, indent=4, separators=(',',': '))



here
#Trainable/total parameters: 0/24033347 	 Fraction: 0.00% 
0
1
2
3
0
1
2
3
0
1
2
3
0
1
2
3
0
1
2
3
here
#Trainable/total parameters: 0/24033347 	 Fraction: 0.00% 
0
1
2
3
0
1
2
3
0
1
2
3
0
1
2
3
0
1
2
3
here
#Trainable/total parameters: 0/24033347 	 Fraction: 0.00% 
0
1
2
3
0
1
2
3
0
1
2
3
0
1
2
3
0
1
2
3
here
#Trainable/total parameters: 0/24033347 	 Fraction: 0.00% 
0
1
2
3
0
1
2
3
0
1
2
3
0
1
2
3
0
1
2
3
here
#Trainable/total parameters: 0/24033347 	 Fraction: 0.00% 
0
1
2
3
0
1
2
3
0
1
2
3
0
1
2
3
0
1
2
3
here
#Trainable/total parameters: 0/24033347 	 Fraction: 0.00% 
0
1
2
3
0
1
2
3
0
1
2
3
0
1
2
3
0
1
2
3
here
#Trainable/total parameters: 0/24033347 	 Fraction: 0.00% 
0
1
2
3
empty mask
3
empty mask
3
empty mask
3
empty mask
3
empty mask
3
empty mask
3
empty mask
3
empty mask
3
empty mask
3
empty mask
3
empty mask
3
