In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import glob, os
from copy import deepcopy
from tqdm import tqdm

In [2]:
from PIL import Image
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchvision import transforms, models, datasets
from torch.utils.data import Dataset, DataLoader, random_split

from custom_utils.custom_utils import *
from custom_utils.custompnn_utils import *
from custom_utils.dataloader_mod import *
from custom_utils.tent_mod import *
from custom_utils.dataloader_TTA import *

In [3]:
print('GPU availability : ', torch.cuda.is_available())
if torch.cuda.is_available():
    device = torch.device('cuda:4')
else:
    device = torch.device("cpu")

GPU availability :  True


In [4]:
random_seed = 2025

import random
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed) # if use multi-GPU

In [5]:
from sklearn.metrics import f1_score, accuracy_score

In [59]:
mean_ = [0.7682, 0.4299, 0.4733]
std_ = [0.2421, 0.2967, 0.2483]

transform = transforms.Compose([
                transforms.Resize((64,64)),
                transforms.ToTensor(),
             transforms.Normalize(mean=mean_,
                              std=std_)
            ])

dataset = DigitData('.', size=64, transform=transform)
#dataset, _ = random_split(dataset, [0.1, 0.9])
data_loader = DataLoader(dataset, batch_size=64, shuffle=True)
len(data_loader)

99

## Tent

In [60]:
class ResidualAdapter(nn.Module):
    def __init__(self, dim=512, h_dim=64):
        super().__init__()
        self.down = nn.Linear(dim, h_dim)
        self.relu = nn.ReLU()
        self.up = nn.Linear(h_dim, dim)
        
        nn.init.zeros_(self.up.weight)
        nn.init.zeros_(self.up.bias)
    def forward(self, x):
        #print(self.up(self.relu(self.down(x))).sum())
        return x + self.up(self.relu(self.down(x)))

In [61]:
def softmax_entropy_(output, multi = True):
    """Entropy of softmax distribution from logits."""
    temprature = 1
    output = output/temprature
    if multi:
        p = output.softmax(dim = -1).mean(dim=0).clamp(min=1e-12)
    else:
        p = output.softmax(dim = -1).clamp(min=1e-12)
    # p: bsz, 10 -> output: bsz
    entropy = -(p * torch.log(p)).sum(1)
    return entropy

In [62]:

class PrototypicalNetworks(nn.Module):
    def __init__(self, backbone=None, n_way=10, normalize=True, proto=None, adapt=False):
        super(PrototypicalNetworks, self).__init__()
        if backbone:
            self.backbone = backbone
        else:
            resnet_pt = models.resnet18(weights="ResNet18_Weights.IMAGENET1K_V1")
            resnet_pt.fc = nn.Flatten()
            self.backbone = resnet_pt
        self.n_way = n_way
        self.normalize = normalize
        self.proto = proto
        self.adapt = adapt
        if self.adapt:
            self.adapter = ResidualAdapter()
        
    def forward(self, images, dist=None):
        proto = self.proto
        if self.adapt:
            proto = self.adapter(self.proto)
        
        z = F.normalize(self.backbone.forward(images))
        if dist == None:
            dists = self.euclidean_distance(z, proto)  # [Q, N]
        else:
            dists = dist(x, y)
        return -dists
    

    def euclidean_distance(self, x, y):
        n = x.shape[0]  # Q
        m = y.shape[0]  # N
        d = x.shape[1]
        assert d == y.shape[1]

        # x -> [Q, 1, D], y -> [1, N, D]
        x = x.unsqueeze(1).expand(n, m, d)
        y = y.unsqueeze(0).expand(n, m, d)

        return torch.pow(x - y, 2).sum(2)
    
    def GetProto(self):
        return self.proto
    
    def update_proto(self, images, pred, momentum = 0.9):
        z = F.normalize(self.backbone.forward(images)) # bsz, emb_dim
        count_list = torch.tensor([(pred==label).sum() for label in range(10)]).to(self.proto.device) # list of len 10
        z_proto = torch.cat([
            nn.functional.normalize(z[torch.nonzero(pred == label)].mean(0)) if count_list[label]!=0 else torch.zeros(1,z.shape[-1]).to(self.proto.device) for label in range(self.n_way)
        ]).to(self.proto.device)
        momentum_ = count_list*(momentum/count_list.sum()).to(self.proto.device) # list of len 10
        proto_new = self.proto*(1-momentum_).unsqueeze(1) + z_proto*momentum_.unsqueeze(1)
        self.proto = F.normalize(proto_new)
        return None
    

In [64]:
class Tent_PNN(nn.Module):
    """Tent adapts a model by entropy minimization during testing.

    Once tented, a model adapts itself by updating on every forward.
    """
    def __init__(self, model, optimizer, steps=1, episodic=False):
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self.steps = steps
        assert steps > 0, "tent requires >= 1 step(s) to forward and update"
        self.episodic = episodic

        # note: if the model is never reset, like for continual adaptation,
        # then skipping the state copy would save memory
        self.model_state, self.optimizer_state = \
            copy_model_and_optimizer(self.model, self.optimizer)

    def forward(self, x):
        if self.episodic:
            self.reset()

        for _ in range(self.steps):
            outputs = forward_and_adapt_(x, self.model, self.optimizer)

        return outputs

    def reset(self):
        if self.model_state is None or self.optimizer_state is None:
            raise Exception("cannot reset without saved model/optimizer state")
        load_model_and_optimizer(self.model, self.optimizer,
                                 self.model_state, self.optimizer_state)
        
def copy_model_and_optimizer(model, optimizer):
    """Copy the model and optimizer states for resetting after adaptation."""
    model_state = deepcopy(model.state_dict())
    optimizer_state = deepcopy(optimizer.state_dict())
    return model_state, optimizer_state

In [94]:
def forward_and_adapt_(x, model, optimizer):
    """Forward and adapt model on batch of data.

    Measure entropy of the model prediction, take gradients, and update params.
    """
    # forward
    outputs = model(x)
    # adapt
    loss = softmax_entropy(outputs).mean(0)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # update prototype
    with torch.no_grad():
        outputs = model(x)
        log_p_y = F.log_softmax(outputs, dim = 1)
        tmp_pred = torch.max(log_p_y.data, axis = 1)[1]
        model.update_proto(x, tmp_pred)
    # return outputs
    return model(x)

In [646]:
class Tent_PNN2(nn.Module):
    """Tent adapts a model by entropy minimization during testing.

    Once tented, a model adapts itself by updating on every forward.
    """
    def __init__(self, model, optimizer, optimizer2, steps=1, episodic=False, adapt_step=3):
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self.optimizer2 = optimizer
        self.steps = steps
        assert steps > 0, "tent requires >= 1 step(s) to forward and update"
        self.episodic = episodic
        self.adapt_step = adapt_step

        # note: if the model is never reset, like for continual adaptation,
        # then skipping the state copy would save memory
        self.model_state, self.optimizer_state = \
            copy_model_and_optimizer(self.model, self.optimizer)

    def forward(self, x):
        if self.episodic:
            self.reset()

        for idx in range(self.steps):
            outputs = forward_and_adapt_2(x, self.model, self.optimizer, self.optimizer2, self.adapt_step, idx)

        return outputs

    def reset(self):
        if self.model_state is None or self.optimizer_state is None:
            raise Exception("cannot reset without saved model/optimizer state")
        load_model_and_optimizer(self.model, self.optimizer,
                                 self.model_state, self.optimizer_state)
        
        
def forward_and_adapt_2(x, model, optimizer, optimizer2, adapt_step, idx):
    """Forward and adapt model on batch of data.

    Measure entropy of the model prediction, take gradients, and update params.
    """
    # forward
    # adapt
    if idx == 0:
        for _ in range(adapt_step):
            optimizer2.zero_grad()
            outputs = model(x)
            loss = softmax_entropy(outputs).mean(0)
            
            loss.backward()
            optimizer2.step()

    optimizer.zero_grad()
    outputs = model(x)
    z = F.normalize(model.backbone.forward(x))
    
    loss = softmax_entropy(outputs).mean(0)
    loss.backward()
    optimizer.step()

    if idx == 0:
        # update prototype
        with torch.no_grad():
            outputs = model(x)
            log_p_y = F.log_softmax(outputs, dim = 1)
            tmp_pred = torch.max(log_p_y.data, axis = 1)[1]
            #model.update_proto(x, tmp_pred)
        # return outputs
    return model(x)

In [622]:
for image, y in data_loader:
    print(image.shape)
    image = image.to(device)
    output = model.backbone.forward(image)
    print(output.shape)
    break

torch.Size([64, 3, 64, 64])
torch.Size([64, 512])


In [600]:
import custom_utils.tent_mod as tent

In [648]:
model = PrototypicalNetworks(proto=torch.load('model/20250428_proto.pt'), adapt=True)

model.load_state_dict(torch.load('model/20250428_protonet.pt'), strict=False)

_IncompatibleKeys(missing_keys=['adapter.down.weight', 'adapter.down.bias', 'adapter.up.weight', 'adapter.up.bias'], unexpected_keys=[])

In [649]:
model = custom_configure_model(model)
params, param_names = custom_collect_params(model)
optimizer = optim.Adam(params, lr=2.5e-4)
optimizer2 = optim.Adam(model.adapter.parameters(), lr=1e-3)
tented_model = Tent_PNN2(model, optimizer, optimizer2, steps=1, episodic=False, adapt_step=5)

In [650]:
# for name, param in tented_model.named_parameters():
#     print(name, param.requires_grad)

In [651]:
tented_model = tented_model.to(device)
tented_model.model.proto = tented_model.model.proto.to(device)

In [652]:
preds, targets = test_tent_PNN(data_loader, tented_model)

Acc: 49.59%
F1 : 48.43%


In [645]:
for _ in range(10):
    preds, targets = test_tent_PNN(data_loader, tented_model)


Acc: 53.59%
F1 : 52.83%
Acc: 59.70%
F1 : 59.41%
Acc: 63.52%
F1 : 63.39%
Acc: 66.38%
F1 : 66.20%
Acc: 66.79%
F1 : 66.76%
Acc: 66.70%
F1 : 66.83%
Acc: 68.70%
F1 : 68.69%
Acc: 69.62%
F1 : 69.60%
Acc: 70.16%
F1 : 70.19%
Acc: 71.23%
F1 : 71.37%


In [639]:
# + adapter, 5 steps, adapt_step=10, lr=2.5e-4, 1e-3
ac = np.array([56.62, 50.49, 55.94, 1, 1])
f1 = np.array([56.33, 49.63, 55.41, 1, 1])

ac.mean(), ac.std(), f1.mean(), f1.std()

(33.010000000000005, 26.22230958554185, 32.674, 25.963522565322297)

In [384]:
# + adapter, 1 steps, adapt_step=10, lr=2.5e-4, 1e-3
ac = np.array([53.73, 55.91, 53.43, 55.09, 58.59])
f1 = np.array([53.17, 55.84, 52.38, 54.73, 58.07])

ac.mean(), ac.std(), f1.mean(), f1.std()

(55.35, 1.8541628838912738, 54.838, 2.014223423555589)

In [337]:
# + adapter, 5 steps, adapt_step=3
ac = np.array([51.20, 54.88, 54.22, 50.16, 56.48])
f1 = np.array([50.47, 54.74, 54.00, 49.36, 56.05])

ac.mean(), ac.std(), f1.mean(), f1.std()

(53.388, 2.353128980740325, 52.924, 2.567166531411626)

In [300]:
# + adapter, 1 steps, adapt_step=1
ac = np.array([48.61, 50.55, 50.39, 50.08, 50.62])
f1 = np.array([47.51, 49.59, 49.26, 48.89, 49.60])

ac.mean(), ac.std(), f1.mean(), f1.std()

(50.05, 0.7436396977031278, 48.97, 0.7750354830586804)

In [263]:
# + adapter, 1 steps, adapt_step=3
ac = np.array([49.64, 50.27, 52.95, 54.57, 51.52])
f1 = np.array([48.57, 49.19, 52.21, 53.48, 50.42])

ac.mean(), ac.std(), f1.mean(), f1.std()

(51.79, 1.793309789188694, 50.774, 1.837200043544523)

In [None]:
# base, 5 steps
ac = np.array([51.04, 50.19, 50.66, 52.23, 51.82])
f1 = np.array([50.07, 49.07, 49.66, 51.31, 50.86])

ac.mean(), ac.std(), f1.mean(), f1.std()

In [95]:
def test_tent_PNN(data_loader, tented_model):
    preds = []
    targets = []

    for batch_idx, (data, target) in enumerate(data_loader):
        data, target = data.to(device), target.to(device)

        output = tented_model(data) # -distance: bsz, 10
        pred = torch.max(output.data, axis = 1)[1]

        preds.extend(pred.cpu().numpy())
        targets.extend(target.cpu().numpy())

    print('Acc: {:.2f}%'.format(100*accuracy_score(targets, preds)))
    print('F1 : {:.2f}%'.format(100*f1_score(targets, preds, average='macro')))
    
    return preds, targets

In [215]:
def custom_collect_params(model):
    params = []
    names = []
    for nm, m in model.named_modules():
        # BatchNorm2d 파라미터 (weight, bias)
        if isinstance(m, nn.BatchNorm2d):
            for np, p in m.named_parameters(recurse=False):
                if np in ['weight', 'bias']:
                    params.append(p)
                    names.append(f"{nm}.{np}")
        # Adapter 파라미터 추가 (ResidualAdapter 등)
        # - ResidualAdapter를 정확히 구분하려면 이름 또는 클래스 기반
        # if 'adapter' in nm or isinstance(m, ResidualAdapter):
        #     for np, p in m.named_parameters(recurse=False):
        #         params.append(p)
        #         names.append(f"{nm}.{np}")
    return params, names


def custom_configure_model(model):
    # train mode, because tent optimizes the model to minimize entropy
    model.train()
    # disable grad, to (re-)enable only what tent updates
    model.requires_grad_(False)
    # configure norm for tent updates: enable grad + force batch statisics
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.requires_grad_(True)
            # force use of batch stats in train and eval modes
            m.track_running_stats = False
            m.running_mean = None
            m.running_var = None
            
        if isinstance(m, ResidualAdapter):
            m.requires_grad_(True)
    return model
