In [1]:
from argparse import Namespace
import torch
import torch.nn as nn
import torchvision.models as models
from get_dataset import GetTransformedDataset
from simclr import simclr_framework


class ResNet18(nn.Module):
    """ResNet-18 encoder whose classifier is replaced by an MLP projection head.

    The torchvision ResNet-18 is built without pretrained weights and with
    ``out_dim`` outputs. Its final ``fc`` layer is then wrapped so that the
    head becomes ``Linear(d, d) -> ReLU -> original fc``, where ``d`` is the
    input width of the original ``fc`` layer.

    Args:
        out_dim: Size of the embedding produced by the projection head.
    """

    def __init__(self, out_dim):
        super().__init__()
        self.backbone = self._get_basemodel(out_dim)

        # The backbone's own classifier is reused as the last layer of the
        # head, so the output size stays ``out_dim``.
        final_fc = self.backbone.fc
        hidden_dim = final_fc.in_features
        projection_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            final_fc,
        )
        self.backbone.fc = projection_head

    def _get_basemodel(self, out_dim):
        """Return a ResNet-18 with no pretrained weights and ``out_dim`` outputs."""
        return models.resnet18(weights=None, num_classes=out_dim)

    def forward(self, x):
        """Run ``x`` through the backbone, including the projection head."""
        return self.backbone(x)

In [2]:
def train(args):
    """Run one SimCLR pre-training job on CIFAR-10 with the given settings.

    Builds the multi-view CIFAR-10 training loader, a ``ResNet18`` encoder,
    an Adam optimizer and a constant learning-rate schedule, then hands them
    to ``simclr_framework``, whose ``train`` method is called with the loader.

    Args:
        args: Attribute container (e.g. an ``argparse.Namespace`` instance).
            This function reads ``device``, ``n_views``, ``batch_size``,
            ``workers``, ``out_dim``, ``lr``, ``weight_decay``,
            ``temperature`` and ``epochs``. The whole object is also
            forwarded to ``simclr_framework``.

    Returns:
        None. Progress is reported on stdout.
    """
    # NOTE(review): ``args.seed`` is not consumed here; confirm that
    # ``simclr_framework`` (or its module) seeds the RNGs if runs must be
    # reproducible.
    print(args.device)

    dataset = GetTransformedDataset()
    train_dataset = dataset.get_cifar10_train(args.n_views)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        # Drop the final short batch so every step has exactly ``batch_size``
        # samples (presumably required by the contrastive loss -- confirm in
        # ``simclr_framework``).
        drop_last=True,
    )

    model = ResNet18(out_dim=args.out_dim)

    optimizer = torch.optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)

    # Constant schedule: the multiplicative factor is always 1.0, so the
    # learning rate stays at ``args.lr``. Alternative that was tried:
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0, last_epoch=-1)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 1.0)

    simclr = simclr_framework(model=model, optimizer=optimizer, scheduler=scheduler, args=args)

    # The banner names the dataset actually used (CIFAR-10, tagged "C10" like
    # the log directory) instead of the stale, misspelled "Trining ..._TIN".
    print("="*50)
    print(f"Training SimCLR_C10_lr{args.lr}_wd{args.weight_decay}_temperature{args.temperature}_bt{args.batch_size}_e{args.epochs}")
    print("="*50)
    print('training started..')
    simclr.train(train_loader)
    print('training completed..')


In [3]:
# Hyperparameter grid: one training run per (temperature, lr, weight_decay).
temp_list = [0.07]
lr_list = [1e-3]
wd_list = [0]
for temp in temp_list:
    for lr in lr_list:
        for wd in wd_list:
            # Create a fresh Namespace *instance* for every run. Binding the
            # bare class (``args = Namespace``) would write all settings onto
            # ``argparse.Namespace`` itself, leaking them to every other user
            # of the class and sharing one mutable object across runs.
            args = Namespace()
            # Hyperparameters
            args.batch_size = 512
            args.epochs = 300
            args.lr = lr
            args.temperature = temp
            args.weight_decay = wd
            # Other settings (those not read by ``train`` are presumably
            # consumed by ``simclr_framework`` -- confirm there).
            args.fp16_precision = False
            args.device = torch.device('cuda')
            args.gpu_index = 0
            args.log_every_n_steps = 1
            args.n_views = 2
            args.out_dim = 100
            args.seed = 1
            args.workers = 8
            # The log directory name encodes the hyperparameters of this run.
            args.log_dir = f"/root/Lab3-1/logs/simclr_C10_lr{args.lr}_wd{args.weight_decay}_temp{args.temperature}_bt{args.batch_size}_e{args.epochs}"
            train(args)

In [None]:
# import logging
# import os
# import torch
# import torch.nn.functional as F
# from torch.cuda.amp import GradScaler, autocast
# from torch.utils.tensorboard import SummaryWriter
# from tqdm import tqdm
# from utils import save_config_file, accuracy, save_checkpoint

# torch.manual_seed(0)

# def setup_logging(log_dir, log_filename):
#     # Remove any handlers already attached to the root logger
#     for handler in logging.root.handlers[:]:
#         logging.root.removeHandler(handler)
#     # Configure the new logger
#     logging.basicConfig(filename=os.path.join(log_dir, log_filename), level=logging.DEBUG)

# class simclr_extra_train(object):

#     def __init__(self, *args, **kwargs):
#         self.args = kwargs['args']
#         self.model = kwargs['model'].to(self.args.device)
#         self.optimizer = kwargs['optimizer']
#         self.scheduler = kwargs['scheduler']
#         self.log_dir = self.args.log_dir
#         self.criterion = torch.nn.CrossEntropyLoss().to(self.args.device)

#     def info_nce_loss(self, features):

#         labels = torch.cat([torch.arange(self.args.batch_size)
#                            for i in range(self.args.n_views)], dim=0)
#         labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
#         labels = labels.to(self.args.device)

#         features = F.normalize(features, dim=1)

#         similarity_matrix = torch.matmul(features, features.T)
#         mask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.args.device)
#         labels = labels[~mask].view(labels.shape[0], -1)
#         similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)

#         positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
#         negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)

#         logits = torch.cat([positives, negatives], dim=1)
#         labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.args.device)

#         logits = logits / self.args.temperature
#         return logits, labels

#     def train(self, train_loader, start):

#         scaler = GradScaler(enabled=self.args.fp16_precision)
#         save_config_file(self.args.log_dir, self.args)

#         n_iter = 0
#         setup_logging(self.log_dir, 'training.log')
#         logging.info(f"Start SimCLR training for {self.args.epochs} epochs.")
#         logging.info(f"Training with gpu: {self.args.device}.")

#         for epoch_counter in range(start, self.args.epochs):
#             for images, _ in tqdm(train_loader):
#                 images = torch.cat(images, dim=0)

#                 images = images.to(self.args.device)

#                 with autocast(enabled=self.args.fp16_precision):
#                     features = self.model(images)
#                     logits, labels = self.info_nce_loss(features)
#                     loss = self.criterion(logits, labels)

#                 self.optimizer.zero_grad()

#                 scaler.scale(loss).backward()

#                 scaler.step(self.optimizer)
#                 scaler.update()
#                 if n_iter % self.args.log_every_n_steps == 0:
#                     top1, top5 = accuracy(logits, labels, topk=(1, 5))
#                 n_iter += 1

#             if epoch_counter >= 10:
#                 self.scheduler.step()
#             logging.debug(f"Epoch: {epoch_counter}\tLoss: {loss}\tTop1 accuracy: {top1[0]}")


#         checkpoint_name = 'checkpoint_{:04d}.pth.tar'.format(epoch_counter+1)
#         save_checkpoint({
#             'epoch': epoch_counter,
#             'arch': 'resnet18',
#             'state_dict': self.model.state_dict(),
#             'optimizer': self.optimizer.state_dict(),
#         }, is_best=False, filename=os.path.join(args.log_dir, checkpoint_name))
#         logging.info(f"Model checkpoint and metadata has been saved at {args.log_dir}.")
#         logging.info("Training has finished.")


# args = Namespace
# # Hyperparameters
# args.batch_size = 512
# args.epochs = 300
# args.lr = 1e-3
# args.temperature = 0.07
# args.weight_decay = 0
# # Other settings
# args.fp16_precision = False
# args.device = torch.device('cuda')
# args.gpu_index = 0
# args.log_every_n_steps = 1
# args.n_views = 2
# args.out_dim = 100
# args.seed = 1
# args.workers = 8
# args.log_dir = f"/root/Lab3-1/logs/extra_simclr_C10_lr{args.lr}_wd{args.weight_decay}_temp{args.temperature}_bt{args.batch_size}_e{args.epochs}"

# print(args.device)

# dataset = GetTransformedDataset()
# train_dataset = dataset.get_cifar10_train(args.n_views)
# train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)

# model = ResNet18(out_dim=args.out_dim).to(args.device)
# optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
# scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 1.0)

# # Load the weights saved partway through training
# checkpoint_path = './logs/simclr_C10_lr0.001_wd0_temp0.07_bt512_e300/checkpoint_0200.pth.tar'
# checkpoint = torch.load(checkpoint_path)

# model.load_state_dict(checkpoint['state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer'])
# start_epoch = 200 + 1  # resume training from the last saved epoch

# simclr = simclr_extra_train(model=model, optimizer=optimizer, scheduler=scheduler, args=args)

# print("="*50)
# print(f"Trining SimCLR_TIN_lr{args.lr}_wd{args.weight_decay}_temperature{args.temperature}_bt{args.batch_size}_e{args.epochs}")
# print("="*50)
# print('training started..')
# simclr.train(train_loader, start_epoch)
# print('training completed..')

