In [1]:
from utils import *
import os
from data import *
from torchvision.models import resnet50, ResNet50_Weights
from torch import nn
from train import *

In [2]:
# ---- Group handling / loss options (all forwarded to train()) ----
reweight_groups = False           # also used when building the training loader
automatic_adjustment = False
use_normalized_loss = False
robust = False
robust_step_size = 0.01
alpha = 0.2
gamma = 0.1
generalization_adjustment = "0"   # kept as a string, exactly as train() receives it

# ---- Optimisation ----
lr = 0.001
batch_size = 64
weight_decay = 0
n_epochs = 10
seed = 0

# ---- Model ----
# Architecture name only. The network object itself is bound to `model` when it
# is built, so a distinct name avoids one variable holding two kinds of value.
model_name = "resnet50"

# ---- Logging / checkpointing (each setting assigned exactly once) ----
log_dir = "./logs"
mode = 'w'                        # file mode for the text and CSV loggers
log_every = 50
save_step = 10
save_best = True
save_last = False

In [3]:
# Create the log directory (no-op if it already exists), open the text logger,
# and seed everything before any data is touched.
os.makedirs(log_dir, exist_ok=True)
logger = Logger(os.path.join(log_dir, 'log.txt'), mode=mode)
set_seed(seed)

# Load the train/val/test splits without augmentation.
# NOTE(review): the recorded output of this notebook reports the same size for
# val and test (1199). The per-group validation and test metrics printed during
# training are also identical. If this is the standard Waterbirds release, the
# test split is expected to be larger than val, so confirm that
# get_confounder_splits is not returning the validation split twice.
train_data, val_data, test_data = get_confounder_splits(
    data_dir="../datasets/waterbird", augment_data=False)

# Only the training loader honours `reweight_groups`; evaluation loaders never reweight.
loader_kwargs = {'batch_size': batch_size, 'pin_memory': False}
train_loader = train_data.get_loader(
    train=True, reweight_groups=reweight_groups, **loader_kwargs)
val_loader = val_data.get_loader(
    train=False, reweight_groups=False, **loader_kwargs)
test_loader = test_data.get_loader(
    train=False, reweight_groups=False, **loader_kwargs)

print(
    f'Train data: {len(train_data)} | Val data:  {len(val_data)} | Test data:  {len(test_data)}')

# Loaders and datasets bundled under the keys that train() receives via `dataset=`.
data = {
    'train_loader': train_loader,
    'val_loader': val_loader,
    'test_loader': test_loader,
    'train_data': train_data,
    'val_data': val_data,
    'test_data': test_data,
}
n_classes = train_data.n_classes

Train data: 4795 | Val data:  1199 | Test data:  1199


In [4]:
# ImageNet-pretrained ResNet-50 with its classifier head replaced so it emits
# one logit per dataset class (instead of a hard-coded 2).
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2, progress=True)
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=n_classes)

logger.flush()

# reduction='none' keeps per-sample losses; presumably train() aggregates them
# (e.g. per group) itself — TODO confirm against train().
criterion = nn.CrossEntropyLoss(reduction='none')


def _make_csv_logger(split):
    """Create the per-batch CSV logger for `split`, writing to <log_dir>/<split>.csv."""
    # Every split is sized with train_data.n_groups (as before); the splits are
    # assumed to share a single group definition.
    return CSVBatchLogger(
        os.path.join(log_dir, f'{split}.csv'), train_data.n_groups, mode=mode)


train_csv_logger = _make_csv_logger('train')
val_csv_logger = _make_csv_logger('val')
test_csv_logger = _make_csv_logger('test')

In [5]:
# Run training (val/test are evaluated each epoch, per the recorded output).
# The CSV loggers are closed in `finally` so their files are released even when
# training is interrupted, as in the recorded run, which ended in KeyboardInterrupt.
try:
    train(
        model=model,
        criterion=criterion,
        dataset=data,
        logger=logger,
        train_csv_logger=train_csv_logger,
        val_csv_logger=val_csv_logger,
        test_csv_logger=test_csv_logger,
        alpha=alpha,
        gamma=gamma,
        robust=robust,
        robust_step_size=robust_step_size,
        use_normalized_loss=use_normalized_loss,
        generalization_adjustment=generalization_adjustment,
        lr=lr,
        weight_decay=weight_decay,
        n_epochs=n_epochs,
        log_every=log_every,
        save_step=save_step,
        save_last=save_last,
        save_best=save_best,
        log_dir=log_dir,
        reweight_groups=reweight_groups,
        automatic_adjustment=automatic_adjustment)
finally:
    for csv_logger in (train_csv_logger, val_csv_logger, test_csv_logger):
        csv_logger.close()


Epoch [0]:
Training:


 65%|██████▌   | 49/75 [10:24<05:31, 12.76s/it]

Average incurred loss: 0.439  
Average sample loss: 0.439  
Average acc: 0.799  
[n = 2324]:	loss = 0.269  exp loss = 0.155  adjusted loss = 0.155  adv prob = 0.250000   acc = 0.983
[n = 132]:	loss = 0.438  exp loss = 0.522  adjusted loss = 0.522  adv prob = 0.250000   acc = 0.939
[n = 45]:	loss = 1.345  exp loss = 1.477  adjusted loss = 1.477  adv prob = 0.250000   acc = 0.000
[n = 699]:	loss = 0.947  exp loss = 0.741  adjusted loss = 0.741  adv prob = 0.250000   acc = 0.212


100%|██████████| 75/75 [16:09<00:00, 12.92s/it]

Average incurred loss: 0.223  
Average sample loss: 0.223  
Average acc: 0.940  
[n = 1174]:	loss = 0.111  exp loss = 0.103  adjusted loss = 0.103  adv prob = 0.250000   acc = 0.999
[n = 52]:	loss = 0.532  exp loss = 0.493  adjusted loss = 0.493  adv prob = 0.250000   acc = 0.750
[n = 11]:	loss = 1.417  exp loss = 1.471  adjusted loss = 1.471  adv prob = 0.250000   acc = 0.182
[n = 358]:	loss = 0.510  exp loss = 0.518  adjusted loss = 0.518  adv prob = 0.250000   acc = 0.796






Validation:


100%|██████████| 19/19 [01:26<00:00,  4.54s/it]

Average incurred loss: 0.509  
Average sample loss: 0.506  
Average acc: 0.716  
[n = 467]:	loss = 0.096  exp loss = 0.095  adjusted loss = 0.095  adv prob = 0.250000   acc = 0.996
[n = 466]:	loss = 0.649  exp loss = 0.613  adjusted loss = 0.613  adv prob = 0.250000   acc = 0.577
[n = 133]:	loss = 1.575  exp loss = 1.636  adjusted loss = 1.636  adv prob = 0.250000   acc = 0.068
[n = 133]:	loss = 0.403  exp loss = 0.432  adjusted loss = 0.432  adv prob = 0.250000   acc = 0.865





\Test:


100%|██████████| 19/19 [01:18<00:00,  4.12s/it]

Average incurred loss: 0.509  
Average sample loss: 0.506  
Average acc: 0.716  
[n = 467]:	loss = 0.096  exp loss = 0.095  adjusted loss = 0.095  adv prob = 0.250000   acc = 0.996
[n = 466]:	loss = 0.649  exp loss = 0.613  adjusted loss = 0.613  adv prob = 0.250000   acc = 0.577
[n = 133]:	loss = 1.575  exp loss = 1.636  adjusted loss = 1.636  adv prob = 0.250000   acc = 0.068
[n = 133]:	loss = 0.403  exp loss = 0.432  adjusted loss = 0.432  adv prob = 0.250000   acc = 0.865





Current lr: 0.001000
Current validation accuracy: 0.7155963182449341
Best model saved at epoch 0


Epoch [1]:
Training:


 65%|██████▌   | 49/75 [10:48<05:42, 13.18s/it]

Average incurred loss: 0.178  
Average sample loss: 0.178  
Average acc: 0.947  
[n = 2319]:	loss = 0.076  exp loss = 0.062  adjusted loss = 0.062  adv prob = 0.250000   acc = 0.997
[n = 124]:	loss = 0.738  exp loss = 0.856  adjusted loss = 0.856  adv prob = 0.250000   acc = 0.548
[n = 38]:	loss = 1.724  exp loss = 1.777  adjusted loss = 1.777  adv prob = 0.250000   acc = 0.132
[n = 719]:	loss = 0.330  exp loss = 0.289  adjusted loss = 0.289  adv prob = 0.250000   acc = 0.894


 95%|█████████▍| 71/75 [15:52<00:53, 13.41s/it]


KeyboardInterrupt: 