In [7]:
# Colab-only setup: mount Google Drive at /content/drive so the project
# directory entered in the next cell is reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [8]:
# Switch the working directory to the project folder on Drive; the local
# module imports and the relative data/checkpoint paths used in later cells
# are resolved against it.
%cd '/content/drive/My Drive/Fruit_States_RnC'

/content/drive/.shortcut-targets-by-id/1bNsY41NQ7yt_0gKqP3j-2gJ1pdli4usp/Fruit_States_RnC


In [9]:
import argparse
import os
import sys
import logging
import torch
import time
from model_fruits import Encoder, model_dict
from dataset_fruits import *
from utils_fruits import *

# NOTE: this rebinds the built-in `print` for the whole notebook namespace.
# Every print(...) in the cells below is therefore a logging.info(...) call and
# is emitted through the handlers that parse_option() installs on the root logger.
print = logging.info

In [10]:
def parse_option(args=None):
    """Parse run options and set up logging for regressor training.

    Args:
        args: list of CLI-style tokens. ``None`` is treated as an empty list,
            so every option takes its default and ``sys.argv`` is never read.

    Returns:
        argparse.Namespace with the parsed options plus two derived fields:
        ``model_name`` (run identifier, or the resume file name without its
        trailing ``_last.pth``-length suffix) and ``save_folder`` (directory
        part of ``--ckpt``).

    Side effects:
        Replaces all root-logger handlers with a file handler
        (``<save_folder>/<model_name>.log``) and a stream handler.
    """
    dataset_names = ['FruitsDataset', 'FruitsDatasetV2', 'FruitsDatasetRGB', 'FruitsDataset30C']
    backbone_names = ['resnet18', 'resnet50']

    parser = argparse.ArgumentParser('argument for training')
    add = parser.add_argument

    # reporting cadence
    add('--print_freq', type=int, default=10, help='print frequency')
    add('--save_freq', type=int, default=50, help='save frequency')

    # optimisation
    add('--batch_size', type=int, default=256, help='batch_size')
    add('--num_workers', type=int, default=16, help='num of workers to use')
    add('--epochs', type=int, default=200, help='number of training epochs')
    add('--learning_rate', type=float, default=0.001, help='learning rate')
    add('--lr_decay_rate', type=float, default=0.8, help='decay rate for learning rate')
    add('--weight_decay', type=float, default=0, help='weight decay')
    add('--momentum', type=float, default=0.9, help='momentum')
    add('--trial', type=str, default='0', help='id for recording multiple runs')

    # data / model
    add('--data_folder', type=str, default='./data_fruits', help='path to custom dataset')
    add('--dataset', type=str, default='FruitsDataset', choices=dataset_names, help='dataset')
    add('--model', type=str, default='resnet18', choices=backbone_names)
    add('--resume', type=str, default='', help='resume ckpt path')
    add('--aug', type=str, default='crop,flip,color,grayscale,rotate', help='augmentations')

    # pretrained encoder
    add('--ckpt', type=str, default='save/FruitsDataset_models/RnC_FruitsDataset_resnet18_ep_400_lr_0.5_d_0.1_wd_0.0001_mmt_0.9_bsz_256_aug_crop,flip,rotate_temp_2_label_l1_feature_l2_trial_0/last.pth', help='path to the trained encoder')

    opt = parser.parse_args(args=[] if args is None else args)

    name_fields = (
        opt.dataset, opt.epochs, opt.learning_rate, opt.lr_decay_rate,
        opt.weight_decay, opt.momentum, opt.batch_size, opt.trial,
    )
    opt.model_name = 'Regressor_{}_ep_{}_lr_{}_d_{}_wd_{}_mmt_{}_bsz_{}_trial_{}'.format(*name_fields)
    if opt.resume:
        # Reuse the name of the run being resumed: file name minus the suffix.
        resume_file = opt.resume.rpartition('/')[2]
        opt.model_name = resume_file[:-len('_last.pth')]
    # Logs are written next to the encoder checkpoint.
    opt.save_folder = opt.ckpt.rpartition('/')[0]

    # Drop handlers left over from a previous call so records are not duplicated.
    logging.root.handlers = []
    log_path = os.path.join(opt.save_folder, f'{opt.model_name}.log')
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s | %(message)s",
        handlers=[logging.FileHandler(log_path), logging.StreamHandler()],
    )

    print(f"Model name: {opt.model_name}")
    print(f"Options: {opt}")

    return opt

In [11]:
def set_loader(opt, num_channel=3):
    """Create the train / val / test data loaders.

    The dataset class is looked up by name (``opt.dataset``) in this module's
    globals. The train split uses the 'train' transforms; val and test both
    use the 'val' transforms. Only the train loader shuffles.

    Args:
        opt: options namespace; reads ``aug``, ``dataset``, ``data_folder``,
            ``batch_size`` and ``num_workers``.
        num_channel: forwarded to ``get_transforms``.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    train_transform, val_transform = (
        get_transforms(split=mode, aug=opt.aug, num_channel=num_channel)
        for mode in ('train', 'val')
    )
    print(f"Train Transforms: {train_transform}")
    print(f"Val Transforms: {val_transform}")

    dataset_cls = globals()[opt.dataset]

    def make_dataset(split, transform):
        # One dataset instance per split, all rooted at the same data folder.
        return dataset_cls(data_folder=opt.data_folder, transform=transform, split=split)

    train_dataset = make_dataset('train', train_transform)
    val_dataset = make_dataset('val', val_transform)
    test_dataset = make_dataset('test', val_transform)
    # (A loader over a 'full' split was previously built here and is disabled.)

    print(f'Train set size: {len(train_dataset)}\t'
          f'Val set size: {len(val_dataset)}\t'
          f'Test set size: {len(test_dataset)}')

    def make_loader(dataset, shuffle):
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=opt.batch_size,
            shuffle=shuffle,
            num_workers=opt.num_workers,
            pin_memory=True,
        )

    return (
        make_loader(train_dataset, True),
        make_loader(val_dataset, False),
        make_loader(test_dataset, False),
    )


def set_model(opt, num_channel=3):
    """Build the encoder, the linear regressor and the L1 criterion.

    The encoder weights are restored from ``opt.ckpt`` (entry ``'model'``).
    With more than one CUDA device the encoder backbone is wrapped in
    ``DataParallel``; otherwise any ``"module."`` substring is removed from the
    checkpoint keys before loading. All three returned objects are moved to
    CUDA, so a CUDA device is required.

    Args:
        opt: options namespace; reads ``model``, ``dataset`` and ``ckpt``.
        num_channel: passed to ``Encoder`` as ``in_channel``.

    Returns:
        Tuple ``(model, regressor, criterion)``.
    """
    model = Encoder(name=opt.model, in_channel=num_channel)
    criterion = torch.nn.L1Loss()

    # Regressor maps encoder features (width taken from model_dict) to the label dimension.
    feature_dim = model_dict[opt.model][1]
    label_dim = get_label_dim(opt.dataset)
    regressor = torch.nn.Linear(feature_dim, label_dim)

    ckpt = torch.load(opt.ckpt, map_location='cpu')
    encoder_weights = ckpt['model']

    use_data_parallel = torch.cuda.device_count() > 1
    if use_data_parallel:
        model.encoder = torch.nn.DataParallel(model.encoder)
    else:
        encoder_weights = {
            key.replace("module.", ""): tensor
            for key, tensor in encoder_weights.items()
        }

    model, regressor, criterion = model.cuda(), regressor.cuda(), criterion.cuda()
    torch.backends.cudnn.benchmark = True

    model.load_state_dict(encoder_weights)
    print(f"<=== Epoch [{ckpt['epoch']}] checkpoint Loaded from {opt.ckpt}!")

    return model, regressor, criterion


def train(train_loader, model, regressor, criterion, optimizer, epoch, opt):
    """Run one training epoch of the regressor on top of the encoder.

    The encoder is kept in eval mode and evaluated under ``torch.no_grad()``;
    only ``regressor`` receives gradients from ``criterion``. Timing and loss
    statistics are logged every ``opt.print_freq`` batches.

    Args:
        train_loader: iterable yielding ``(images, labels)`` batches.
        model: feature extractor.
        regressor: module trained on the extracted features.
        criterion: loss applied to ``(output, labels)``.
        optimizer: optimizer stepped once per batch.
        epoch: epoch number, used only in the log line.
        opt: options namespace; reads ``print_freq``.
    """
    model.eval()
    regressor.train()

    batch_time, data_time, losses = AverageMeter(), AverageMeter(), AverageMeter()

    log_template = ('Train: [{0}][{1}/{2}]\t'
                    'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    'loss {loss.val:.3f} ({loss.avg:.3f})')

    tick = time.time()
    for step, (images, labels) in enumerate(train_loader, start=1):
        # Time spent waiting for the loader to deliver this batch.
        data_time.update(time.time() - tick)

        images = images.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)

        with torch.no_grad():
            features = model(images)

        loss = criterion(regressor(features.detach()), labels)
        losses.update(loss.item(), labels.shape[0])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - tick)
        tick = time.time()

        if step % opt.print_freq == 0:
            print(log_template.format(
                epoch, step, len(train_loader),
                batch_time=batch_time, data_time=data_time, loss=losses))
            sys.stdout.flush()


def validate(val_loader, model, regressor):
    """Return the batch-size-weighted mean L1 loss over ``val_loader``.

    Both modules are put in eval mode and evaluated without gradients.

    Args:
        val_loader: iterable yielding ``(images, labels)`` batches.
        model: feature extractor.
        regressor: module applied to the extracted features.

    Returns:
        ``AverageMeter.avg`` of the per-batch L1 losses.
    """
    model.eval()
    regressor.eval()

    l1 = torch.nn.L1Loss()
    meter = AverageMeter()

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.cuda(), labels.cuda()
            predictions = regressor(model(images))
            meter.update(l1(predictions, labels).item(), labels.shape[0])

    return meter.avg

def validate_modified(loader, model, regressor, split=''):
    """Evaluate like ``validate`` and also dump the predictions to disk.

    Predictions for every batch are collected, concatenated along axis 0 and
    saved to ``outputs/predicted_y_<split>.npy`` (relative to the current
    working directory). The ``outputs`` directory is created if missing.

    Args:
        loader: iterable yielding ``(images, labels)`` batches.
        model: feature extractor.
        regressor: module applied to the extracted features.
        split: tag inserted into the output file name.

    Returns:
        Batch-size-weighted mean L1 loss over ``loader``.
    """
    # The notebook's visible import cell has no explicit numpy import (`np`
    # would have to arrive via a star import), so import it locally.
    import numpy as np

    model.eval()
    regressor.eval()

    losses = AverageMeter()
    criterion_l1 = torch.nn.L1Loss()
    all_outputs = []

    with torch.no_grad():
        for idx, (images, labels) in enumerate(loader):
            images = images.cuda()
            labels = labels.cuda()
            bsz = labels.shape[0]

            features = model(images)
            output = regressor(features)

            all_outputs.append(output.detach().cpu().numpy())

            loss_l1 = criterion_l1(output, labels)
            losses.update(loss_l1.item(), bsz)

    all_outputs = np.concatenate(all_outputs, axis=0)
    # np.save does not create parent directories; make sure the target exists.
    os.makedirs('outputs', exist_ok=True)
    np.save('outputs/predicted_y_' + split + '.npy', all_outputs)

    return losses.avg

def main_linear(args=None, num_channel=3):
    """Train the linear regressor on encoder features and report the test error.

    Workflow: parse options, build loaders / model / optimizer, optionally
    resume the regressor from ``opt.resume``, train for the remaining epochs
    while tracking the best validation L1 error, then reload the best
    regressor checkpoint and evaluate it on the test split.

    Checkpoints are written to ``opt.save_folder`` as
    ``<model_name>_best.pth`` and ``<model_name>_last.pth``. Both store
    ``epoch``, ``state_dict`` and ``best_error``; the last checkpoint also
    stores ``last_error``.

    Args:
        args: list of CLI-style tokens forwarded to ``parse_option``.
        num_channel: forwarded to ``set_loader`` and ``set_model``.
    """
    opt = parse_option(args)

    # build data loader
    train_loader, val_loader, test_loader = set_loader(opt, num_channel=num_channel)

    # build model and criterion
    model, regressor, criterion = set_model(opt, num_channel=num_channel)

    # build optimizer (only the regressor is handed to it)
    optimizer = set_optimizer(opt, regressor)

    save_file_best = os.path.join(opt.save_folder, f"{opt.model_name}_best.pth")
    save_file_last = os.path.join(opt.save_folder, f"{opt.model_name}_last.pth")
    best_error = 1e5

    start_epoch = 1
    if len(opt.resume):
        ckpt_state = torch.load(opt.resume)
        regressor.load_state_dict(ckpt_state['state_dict'])
        start_epoch = ckpt_state['epoch'] + 1
        if 'best_error' in ckpt_state:
            best_error = ckpt_state['best_error']
        elif os.path.exists(save_file_best):
            # '_last' checkpoints written before 'best_error' was stored only
            # carry 'last_error'; recover the running best from the sibling
            # best checkpoint so it is not overwritten by a worse model.
            best_error = torch.load(save_file_best, map_location='cpu')['best_error']
        print(f"<=== Epoch [{ckpt_state['epoch']}] Resumed from {opt.resume}!")

    # training routine
    for epoch in range(start_epoch, opt.epochs + 1):
        adjust_learning_rate(opt, optimizer, epoch)

        # train for one epoch
        train(train_loader, model, regressor, criterion, optimizer, epoch, opt)

        valid_error = validate(val_loader, model, regressor)
        print('Val L1 error: {:.3f}'.format(valid_error))

        is_best = valid_error < best_error
        best_error = min(valid_error, best_error)
        print(f"Best Error: {best_error:.3f}")

        if is_best:
            torch.save({
                'epoch': epoch,
                'state_dict': regressor.state_dict(),
                'best_error': best_error
            }, save_file_best)

        # 'best_error' is stored here too so that resuming from the last
        # checkpoint can restore the running best.
        torch.save({
            'epoch': epoch,
            'state_dict': regressor.state_dict(),
            'last_error': valid_error,
            'best_error': best_error
        }, save_file_last)

    print("=" * 120)
    print("Test best model on test set...")
    checkpoint = torch.load(save_file_best)
    regressor.load_state_dict(checkpoint['state_dict'])
    print(f"Loaded best model, epoch {checkpoint['epoch']}, best val error {checkpoint['best_error']:.3f}")
    # validate_modified(<loader>, model, regressor, "<split>") can be used here
    # instead to also dump the predictions for a split to disk.
    test_loss = validate(test_loader, model, regressor)
    to_print = 'Test L1 error: {:.3f}'.format(test_loss)
    print(to_print)

In [15]:
# Option tokens for this run, grouped by purpose and handed to main_linear
# exactly as a command line would be.
args = []

# reporting cadence
args += ['--print_freq', '10', '--save_freq', '50']

# optimisation
args += [
    '--batch_size', '256',
    '--num_workers', '16',
    '--epochs', '200',
    '--learning_rate', '0.001',
    '--lr_decay_rate', '0.8',
    '--weight_decay', '0',
    '--momentum', '0.9',
    '--trial', '3',
]

# data / model
args += [
    '--data_folder', './data_fruits_v2',
    '--dataset', 'FruitsDatasetV2',
    '--model', 'resnet18',
    '--resume', '',
    '--aug', 'crop,flip,rotate',
]

# pretrained encoder checkpoint; previously used alternatives:
# '--ckpt', 'save/FruitsDatasetRGB_models/RnC_FruitsDatasetRGB_resnet18_ep_400_lr_0.5_d_0.9_wd_0.0001_mmt_0.9_bsz_256_aug_crop,flip,rotate_temp_2.0_label_l1_feature_l2_trial_0/last.pth'
# '--ckpt', 'save/FruitsDataset30C_models/RnC_FruitsDataset30C_resnet18_ep_400_lr_0.5_d_0.9_wd_0.0001_mmt_0.9_bsz_256_aug_crop,flip,rotate_temp_2.0_label_l1_feature_l2_trial_0/last.pth'
args += ['--ckpt', 'save_final/FruitsDatasetV2_models/RnC_FruitsDatasetV2_resnet18_ep_400_lr_0.5_d_0.9_wd_0.0001_mmt_0.9_bsz_256_aug_crop,flip,rotate_temp_2.0_label_l1_feature_l2_trial_0/ckpt_epoch_350.pth']

# num_channel is forwarded to the transforms and the encoder
# (presumably the V2 inputs have 8 channels -- confirm against the dataset).
main_linear(args, num_channel=8)

2024-07-01 16:06:23,954 | Model name: Regressor_FruitsDatasetV2_ep_200_lr_0.001_d_0.8_wd_0.0_mmt_0.9_bsz_256_trial_3
2024-07-01 16:06:23,956 | Options: Namespace(print_freq=10, save_freq=50, batch_size=256, num_workers=16, epochs=200, learning_rate=0.001, lr_decay_rate=0.8, weight_decay=0.0, momentum=0.9, trial='3', data_folder='./data_fruits_v2', dataset='FruitsDatasetV2', model='resnet18', resume='', aug='crop,flip,rotate', ckpt='save_final/FruitsDatasetV2_models/RnC_FruitsDatasetV2_resnet18_ep_400_lr_0.5_d_0.9_wd_0.0001_mmt_0.9_bsz_256_aug_crop,flip,rotate_temp_2.0_label_l1_feature_l2_trial_0/ckpt_epoch_350.pth', model_name='Regressor_FruitsDatasetV2_ep_200_lr_0.001_d_0.8_wd_0.0_mmt_0.9_bsz_256_trial_3', save_folder='save_final/FruitsDatasetV2_models/RnC_FruitsDatasetV2_resnet18_ep_400_lr_0.5_d_0.9_wd_0.0001_mmt_0.9_bsz_256_aug_crop,flip,rotate_temp_2.0_label_l1_feature_l2_trial_0')
2024-07-01 16:06:23,957 | Train Transforms: Compose(
    RandomResizedCrop(size=(50, 50), scale=(0.2,