# PRETRAINING

In [1]:
# utils.util

from __future__ import print_function

import math
import torch
import torch.optim as optim
import numpy as np
!pip install ood_metrics
from ood_metrics import calc_metrics
from sklearn.metrics import roc_curve



class TwoCropTransform:
    """Create two crops of the same image"""

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, x):
        return self.transform(x), self.transform(x)


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def adjust_learning_rate(args, optimizer, epoch):
    lr = args.learning_rate
    if args.cosine:
        eta_min = lr * (args.lr_decay_rate ** 3)
        lr = eta_min + (lr - eta_min) * (
                1 + math.cos(math.pi * epoch / args.epochs)) / 2
    else:
        steps = np.sum(epoch > np.asarray(args.lr_decay_epochs))
        if steps > 0:
            lr = lr * (args.lr_decay_rate ** steps)

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer):
    if args.warm and epoch <= args.warm_epochs:
        p = (batch_id + (epoch - 1) * total_batches) / \
            (args.warm_epochs * total_batches)
        lr = args.warmup_from + p * (args.warmup_to - args.warmup_from)

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr


def set_optimizer(opt, model):
    optimizer = optim.SGD(model.parameters(),
                          lr=opt.learning_rate,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)
    return optimizer


def calc_metrics_transformed(ind_score: np.ndarray, ood_score: np.ndarray) -> dict:
    labels = [1] * len(ind_score) + [0] * len(ood_score)
    scores = np.hstack([ind_score, ood_score])

    metric_dict = calc_metrics(scores, labels)
    fpr, tpr, _ = roc_curve(labels, scores)

    metric_dict_transformed = {
        "AUROC": 100 * metric_dict["auroc"],
        #    "TNR at TPR 95%": 100 * (1 - metric_dict["fpr_at_95_tpr"]),
        #   "Detection Acc.": 100 * 0.5 * (tpr + 1 - fpr).max(),
    }
    return metric_dict_transformed



In [2]:
# conResNet

"""ResNet in PyTorch.
ImageNet-Style ResNet
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
Adapted from: https://github.com/bearpaw/pytorch-classification
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, is_last=False):
        super(BasicBlock, self).__init__()
        self.is_last = is_last
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        preact = out
        out = F.relu(out)
        if self.is_last:
            return out, preact
        else:
            return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, is_last=False):
        super(Bottleneck, self).__init__()
        self.is_last = is_last
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        preact = out
        out = F.relu(out)
        if self.is_last:
            return out, preact
        else:
            return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves
        # like an identity. This improves the model by 0.2~0.3% according to:
        # https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for i in range(num_blocks):
            stride = strides[i]
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, layer=100):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        return out


def resnet18(**kwargs):
    return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)


def resnet34(**kwargs):
    return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)


def resnet50(**kwargs):
    return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)


def resnet101(**kwargs):
    return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)


model_dict = {
    'resnet18': [resnet18, 512],
    'resnet34': [resnet34, 512],
    'resnet50': [resnet50, 2048],
    'resnet101': [resnet101, 2048],
}


class LinearBatchNorm(nn.Module):
    """Implements BatchNorm1d by BatchNorm2d, for SyncBN purpose"""

    def __init__(self, dim, affine=True):
        super(LinearBatchNorm, self).__init__()
        self.dim = dim
        self.bn = nn.BatchNorm2d(dim, affine=affine)

    def forward(self, x):
        x = x.view(-1, self.dim, 1, 1)
        x = self.bn(x)
        x = x.view(-1, self.dim)
        return x


class conResNet(nn.Module):
    """backbone + projection head"""

    def __init__(self, name='resnet50', head='mlp', feat_dim=128, n_heads=5):
        super(conResNet, self).__init__()

        model_fun, dim_in = model_dict[name]
        self.total_var = 0
        self.encoder = model_fun()
        self.proj = []
        self.n_heads = n_heads
        if head == 'linear':
            self.head = nn.Linear(dim_in, feat_dim)
        elif head == 'mlp':
            self.proj = nn.ModuleList()
            for _ in range(n_heads):
                pro = nn.Sequential(
                    nn.Linear(dim_in, dim_in),
                    nn.ReLU(inplace=True),
                    nn.Linear(dim_in, feat_dim)
                )
                self.proj.append(pro)

        else:
            raise NotImplementedError(
                'head not supported: {}'.format(head))

    def forward(self, x1, x2):
        f1 = self.encoder(x1)
        f2 = self.encoder(x2)
        res1 = []
        res2 = []
        for i in range(self.n_heads):
            res1.append(F.normalize(self.proj[i](f1), dim=1))
            res2.append(F.normalize(self.proj[i](f2), dim=1))
        feat1 = torch.mean(torch.stack(res1), dim=0)
        feat2 = torch.mean(torch.stack(res2), dim=0)
        feat1_std = torch.sqrt(torch.var(torch.stack(res1), dim=0) + 0.0001)
        feat2_std = torch.sqrt(torch.var(torch.stack(res2), dim=0) + 0.0001)
        features = torch.cat([feat1.unsqueeze(1), feat2.unsqueeze(1)], dim=1)
        features_std = torch.cat([feat1_std.unsqueeze(1), feat2_std.unsqueeze(1)], dim=1)

        return features, features_std


class LinearClassifier(nn.Module):
    """Linear classifier"""

    def __init__(self, name='resnet50', num_classes=10):
        super(LinearClassifier, self).__init__()

        _, dim_in = model_dict[name]
        self.fc = nn.Linear(dim_in, num_classes)

    def forward(self, features):
        return self.fc(features)


class MultiHeadSegResNet(nn.Module):
    def __init__(self, name='resnet50', num_classes=10, n_heads=5):
        super(MultiHeadSegResNet, self).__init__()
        model_fun, _ = model_dict[name]
        self.encoder = model_fun()
        self.n_heads = n_heads
        
        # Create multiple segmentation heads
        self.segmentation_heads = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(512, 256, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                nn.Conv2d(256, 128, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                nn.Conv2d(128, num_classes, kernel_size=1)
            ) for _ in range(n_heads)
        ])
    
    def forward(self, x):
        # Encode the input
        encoded_features = self.encoder(x)
        
        # Get segmentation output from each head
        seg_maps = [head(encoded_features) for head in self.segmentation_heads]
        
        # Stack outputs to compute mean and variance
        seg_maps_stack = torch.stack(seg_maps, dim=0)
        
        # Compute mean and variance across the heads for each pixel
        mean_seg_maps = torch.mean(seg_maps_stack, dim=0)
        variance_seg_maps = torch.var(seg_maps_stack, dim=0)
        
        return mean_seg_maps, variance_seg_maps, seg_maps

In [4]:
# UALoss

from __future__ import print_function

import torch
import torch_xla.core.xla_model as xm
import torch.nn as nn
import torch.nn.functional as F


class UALoss(nn.Module):

    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07, lamda1=1, lamda2=0.1, batch_size=512):
        super(UALoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature
        self.lamda1 = lamda1
        self.lamda2 = lamda2
        self.batch_size = batch_size

    def forward(self, features, features_std, epochs):
        """

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """
        # print(features_std.shape)
        try:
            # Attempt to create a TPU device
            device = xm.xla_device()
            device_type = 'TPU'
        except RuntimeError as e:
            # Fallback to CUDA or CPU
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            device_type = 'GPU' if torch.cuda.is_available() else 'CPU'

        print(f"Using device: {device_type}, {device}")

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]

        mask = torch.eye(batch_size, dtype=torch.float32).to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        anchor_feature = contrast_feature
        anchor_count = contrast_count

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        # uncertainty loss
        std_loss1 = torch.sum(F.relu(self.lamda2 - features_std)) / (2 * self.batch_size)
        std_loss2 = torch.sum(features_std) / (2 * self.batch_size)
        # print(std_loss)
        # nt xnet loss
        loss = loss.view(anchor_count, batch_size).mean()

        if self.lamda1 > 0:
            total_loss = std_loss1 * self.lamda1 + loss
        else:
            total_loss = loss

        return total_loss, std_loss1, std_loss2

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# pretrain.py

import sys
import time
import torch
# from utils.util import AverageMeter
# from utils.util import warmup_learning_rate
# from models.resnet_big import conResNet
# from utils.losses import UALoss
import torch.backends.cudnn as cudnn
from torch import nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla.core.xla_model import optimizer_step


def train(train_loader, model, criterion, optimizer, epoch, opt):
    """one epoch training"""
    device = xm.xla_device()
    model.train()
    
    # Metrics initialization
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    stdlosses = AverageMeter()
    stdlosses2 = AverageMeter()

    # Convert DataLoader for TPU usage
    para_loader = pl.ParallelLoader(train_loader, [device])
    loader = para_loader.per_device_loader(device)

    end = time.time()
    
    for idx, ((image1, image2), labels) in enumerate(loader):
        data_time.update(time.time() - end)

        image1 = image1.to(device)
        image2 = image2.to(device)
        labels = labels.to(device)
        bsz = labels.shape[0]

        # Warm-up and learning rate adjustment
        # Note: You'll need to adjust or implement warmup_learning_rate for TPUs if necessary

        # Forward pass, loss computation, and backward pass
        features, features_std = model(image1, image2)
        loss, std_loss, std_loss2 = criterion(features, features_std, epoch)
        
        losses.update(loss.item(), bsz)
        stdlosses.update(std_loss.item(), bsz)
        stdlosses2.update(std_loss2.item(), bsz)
        
        optimizer.zero_grad()
        loss.backward()
        xm.optimizer_step(optimizer)
        xm.mark_step()  # Marking the step is crucial for TPU computation

        batch_time.update(time.time() - end)
        end = time.time()

        if (idx + 1) % opt.print_freq == 0:
            print(f'Train: [{epoch}][{idx + 1}/{len(loader)}]\t'
                  f'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  f'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  f'Loss {losses.val:.3f} ({losses.avg:.3f})')

    return losses.avg, stdlosses.avg, stdlosses2.avg



def set_model(model_name, temperature, syncBN=False, lamda1=1, lamda2=0.1, batch_size=512, nh=5):
    # Define the device as TPU or fall back to CUDA/CPU
    device = xm.xla_device()
    
    model = conResNet(name=model_name, n_heads=nh)
    criterion = UALoss(temperature=temperature, lamda1=lamda1, lamda2=lamda2, batch_size=batch_size)
    
    model = model.to(device)
    criterion = criterion.to(device)
    
    # Synchronized batch normalization using PyTorch XLA's implementation
    if syncBN:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    return model, criterion

In [7]:
# dataloader

from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
from torchvision import datasets
import torch
# from utils.util import TwoCropTransform


def data_loader(dataset="cifar10", batch_size=512, semi=False, semi_percent=10, num_cores=12):
    if dataset == 'cifar10':
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)
    elif dataset == 'cifar100':
        mean = (0.5071, 0.4867, 0.4408)
        std = (0.2675, 0.2565, 0.2761)
    elif dataset == 'cifar10h':
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)
    elif dataset == "svhn":
        mean = (0.4376821, 0.4437697, 0.47280442)
        std = (0.19803012, 0.20101562, 0.19703614)

    normalize = transforms.Normalize(mean=mean, std=std)
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    sampler = None
    # datasets
    if dataset == "cifar10":
        train_dataset = datasets.CIFAR10(root='../../DATA2/', train=True, download=True, transform=train_transform)
        test_dataset = datasets.CIFAR10(root='../../DATA2/', train=False, download=True, transform=val_transform)

    elif dataset == "cifar100":
        train_dataset = datasets.CIFAR100(root='../../DATA2/', train=True, download=True, transform=train_transform)
        test_dataset = datasets.CIFAR100(root='../../DATA2/', train=False, download=True, transform=val_transform)
    elif dataset == "svhn":
        train_dataset = datasets.SVHN(
            root='.../../DATA2/', split="train", download=True, transform=train_transform
        )
        test_dataset = datasets.SVHN(
            root='../../DATA2/', split="test", download=True, transform=val_transform
        )
    if semi:
        per = semi_percent / 100
        x = int(per * len(train_dataset))
        y = int(len(train_dataset) - x)
        train, _ = random_split(train_dataset, [x, y])
    else:
        train = train_dataset

    train, val = random_split(train, [int(0.8 * len(train)),
                                      len(train) - int(0.8 * len(train))], generator=torch.Generator().manual_seed(42))

    train_loader = DataLoader(train,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_cores,
                              drop_last=False
                              )
    val_loader = DataLoader(val,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=num_cores,
                            drop_last=False)

    test_loader = DataLoader(test_dataset,
                             batch_size=batch_size,
                             num_workers=num_cores,
                             drop_last=False)

    targets = torch.cat([y for x, y in test_loader], dim=0).numpy()
    return train_loader, val_loader, test_loader, targets


def set_loader_simclr(dataset, batch_size, num_workers, data_dir='../../DATA2/', size=32):
    # construct data loader
    if dataset == 'cifar10':
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)
    elif dataset == 'cifar100':
        mean = (0.5071, 0.4867, 0.4408)
        std = (0.2675, 0.2565, 0.2761)
    elif dataset == "svhn":
        mean = (0.4376821, 0.4437697, 0.47280442)
        std = (0.19803012, 0.20101562, 0.19703614)
    else:
        raise ValueError('dataset not supported: {}'.format(dataset))
    normalize = transforms.Normalize(mean=mean, std=std)

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size=size, scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        normalize,
    ])

    if dataset == 'cifar10':
        train_dataset = datasets.CIFAR10(root=data_dir,
                                         transform=TwoCropTransform(train_transform),
                                         download=True)
    elif dataset == 'cifar100':
        train_dataset = datasets.CIFAR100(root=data_dir,
                                          transform=TwoCropTransform(train_transform),
                                          download=True)
    elif dataset == 'svhn':
        train_dataset = datasets.SVHN(
            root=data_dir, split="train", download=True, transform=TwoCropTransform(train_transform)
        )

    else:
        raise ValueError(dataset)

    train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=(train_sampler is None),
        num_workers=num_workers, pin_memory=False, sampler=train_sampler)

    return train_loader

In [8]:
from __future__ import print_function

import sys
import os
import argparse
import time
import math
import pandas as pd
!pip install tensorboard_logger
import tensorboard_logger as tb_logger
import torch
# from Train.pretrain import train, set_model
# from Dataloader.dataloader import set_loader_simclr

# from utils.util import adjust_learning_rate
# from utils.util import set_optimizer

# try:
#    import apex
#    from apex import amp, optimizers
# except ImportError:
#    pass


def parse_option():
    parser = argparse.ArgumentParser('argument for training')

    parser.add_argument('--print_freq', type=int, default=10,
                        help='print frequency')
    parser.add_argument('--batch_size', type=int, default=64,
                        help='batch_size')
    parser.add_argument('--num_workers', type=int, default=12,
                        help='num of workers to use')
    parser.add_argument('--epochs', type=int, default=800,
                        help='number of training epochs')
    parser.add_argument('--ensemble', type=int, default=1,
                        help='number of ensemble')
    # optimization
    parser.add_argument('--learning_rate', type=float, default=0.01,
                        help='learning rate')
    parser.add_argument('--lr_decay_epochs', type=str, default='700,800,900',
                        help='where to decay lr, can be a list')
    parser.add_argument('--lr_decay_rate', type=float, default=0.1,
                        help='decay rate for learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-4,
                        help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum')

    # model dataset
    parser.add_argument('--model', type=str, default='resnet50')
    parser.add_argument('--dataset', type=str, default='cifar10',
                        choices=['cifar10', 'cifar100'], help='dataset')
    parser.add_argument('--data_folder', type=str, default=None, help='path to custom dataset')
    parser.add_argument('--size', type=int, default=32, help='parameter for RandomResizedCrop')

    # hyperparameters
    parser.add_argument('--temp', type=float, default=0.07,
                        help='temperature for loss function')
    parser.add_argument('--nh', type=int, default=10,
                        help='number of heads')
    parser.add_argument('--lamda1', type=float, default=1,
                        help='uncertainty_penalty_weight')
    parser.add_argument('--lamda2', type=float, default=0.8,
                        help='uncertainty_threshold')

    # other setting
    parser.add_argument('--cosine', action='store_true',
                        help='using cosine annealing')
    parser.add_argument('--warm', action='store_true',
                        help='warm-up for large batch training')
    parser.add_argument('--saved_model', type=str, default=".",
                        help='path to save classifier')
    parser.add_argument('--log', type=str, default='.',
                        help='path to save tensorboard logs')
    parser.add_argument('--syncBN', action='store_true', 
                        help='enable synchronized batch normalization')
    

    opt = parser.parse_args()

    # set the path according to the environment
    if opt.data_folder is None:
        opt.data_folder = './DATA'
    opt.tb_path = os.path.join(opt.log, '{}'.format(opt.dataset))
    print(f"THIS: {opt.saved_model}")
    opt.save_folder = os.path.join(opt.saved_model, '{}_experiments'.format(opt.dataset))
    if not os.path.isdir(opt.save_folder):
        os.makedirs(opt.save_folder)
    if opt.batch_size > 256:
        opt.warm = True
    if opt.warm:
        opt.warmup_from = 0.01
        opt.warm_epochs = 10
        if opt.cosine:
            eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
            opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (
                    1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2
        else:
            opt.warmup_to = opt.learning_rate
    return opt


def main():
    sys.argv = ['main_pretrain.py', '--cosine', '--dataset', 'cifar10', '--lamda1', '1', '--lamda2', '0.08', '--epochs', '800']
    opt = parse_option()

    torch.cuda.empty_cache()
    # build data loader
    train_loader = set_loader_simclr(dataset=opt.dataset, batch_size=opt.batch_size, num_workers=opt.num_workers,
                                     size=opt.size)

    for i in range(opt.ensemble):
        torch.manual_seed(i)
        torch.cuda.manual_seed(i)
        model, criterion = set_model(model_name=opt.model, temperature=opt.temp, syncBN=opt.syncBN, lamda1=opt.lamda1,
                                     lamda2=opt.lamda2,
                                     batch_size=opt.batch_size, nh=opt.nh)

        # build optimizer
        optimizer = set_optimizer(opt, model)

        # tensorboard
        logger = tb_logger.Logger(logdir=opt.tb_path, flush_secs=2)

        time1 = time.time()
        l1 = []
        l2 = []
        l3 = []
        # training routine
        for epoch in range(1, opt.epochs + 1):
            adjust_learning_rate(opt, optimizer, epoch)
            # train for one epoch
            time3 = time.time()
            loss, std_loss, std_loss2 = train(train_loader, model, criterion, optimizer, epoch, opt)
            time4 = time.time()
            print('ensemble {}, epoch {}, total time {:.2f}'.format(i, epoch, time4 - time3))
            l1.append(loss)
            l2.append(std_loss)
            l3.append(std_loss2)
            # tensorboard logger
            logger.log_value('loss', loss, epoch)
            logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch)
            # logger.log_value('std', std_loss, epoch)
            checkpoint_file = os.path.join(
            ".",
            'simclr_{}_{}_recent_{}heads_lamda1{}_lamda2{}.pth'.format(opt.dataset, i, opt.nh,
                                                                          opt.lamda1,
                                                                          opt.lamda2))
            torch.save(model.state_dict(), checkpoint_file)

        time2 = time.time()
        print('ensemble {}, total time {:.2f}'.format(i, time2 - time1))
        loss_res = pd.DataFrame({"total_loss": l1, "stdloss1": l2, "stdloss2": l3})
        os.makedirs("./csv_loss", exist_ok=True)
        loss_res.to_csv(
            "./csv_loss/{}_c_{}heads_lamda1{}_lamda2{}.csv".format(opt.dataset, opt.nh, opt.lamda1, opt.lamda2))
        save_file = os.path.join(
            opt.save_folder,
            'simclr_{}_{}_epoch{}_{}heads_lamda1{}_lamda2{}.pth'.format(opt.dataset, i, opt.epochs, opt.nh,
                                                                          opt.lamda1,
                                                                          opt.lamda2))
        torch.save(model.state_dict(), save_file)


if __name__ == '__main__':
    main()

THIS: .
Files already downloaded and verified


I0000 00:00:1710869146.725114   12251 cpu_client.cc:370] TfrtCpuClient created.


Using device: TPU, xla:0
