In [1]:
# imports
import matplotlib.pyplot as plt
import matplotlib
import cv2
import os
import torch 
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import random
import sys

import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision

sys.path.append('../')

from tqdm import tqdm

from resnet_auen_hidden import *
from main import *

In [2]:
import easydict 
args = easydict.EasyDict({ "batch-size": 256, 
                          "epochs": 20, 
                          "data": 0, 
                          'arch':'resnet18',
                          'lr':0.1,
                         'momentum':0.9,
                         'weight_decay':1e-4,
                         'start_epoch':0,
                         'gpu':0,
                          'print_freq':1000})

In [3]:
device = 'cuda'
backbone_model= torchvision.models.resnet18(pretrained=True)
model = resnet18(pretrained=False)
nn.init.constant_(model.auen.bn1.weight, 0)
nn.init.constant_(model.auen.bn2.weight, 0)
print(model.conv1.weight[0,0,0])
model.load_state_dict(backbone_model.state_dict(), strict=False)
print(model.conv1.weight[0,0,0])
model.to(device)
criterion = nn.CrossEntropyLoss().to(device)
# optimizer = torch.optim.SGD(model.parameters(), args.lr,
#                             momentum=args.momentum,
#                             weight_decay=args.weight_decay)
optimizer = optim.Adam(model.parameters(), lr=0.0002)

linear layer is initialized
linear layer is initialized
linear layer is initialized
tensor([ 0.0110, -0.0367, -0.0128, -0.0016, -0.0285, -0.0435, -0.0510],
       grad_fn=<SelectBackward>)
tensor([-0.0104, -0.0061, -0.0018,  0.0748,  0.0566,  0.0171, -0.0127],
       grad_fn=<SelectBackward>)


In [4]:
model.auen.bn1.weight

Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0', requires_grad=True)

In [5]:
# Data loading code
data_dir = '../ILSVRC/Data/CLS-LOC/'
traindir = os.path.join(data_dir, 'train')
valdir = os.path.join(data_dir, 'val')

In [6]:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

train_dataset = datasets.ImageFolder(
    traindir,
    transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))
val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]))

In [7]:
train_sampler = None
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=256, shuffle=(train_sampler is None),
    num_workers=8, pin_memory=True, sampler=train_sampler)

In [8]:
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=32, shuffle=False,
    num_workers=4, pin_memory=True)

In [9]:
model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (auen): AutoEncoder(
    (fc1): Linear(in_features=512, out_features=32, bias=True)
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (fc2): Linear(in_features=32, out_features=512, bias=True)
    (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (gelu): GELU()
  )
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (gelu): GELU()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(

In [None]:
best_acc1 = 0
acc1 = 0
train_loss = []
val_acc = []
for epoch in range(args.start_epoch, args.epochs):
    start_time = time.time()
#     adjust_learning_rate(optimizer, epoch, args)

    # train for one epoch
    epoch_loss = train(train_loader, model, criterion, optimizer, epoch, args)

    # evaluate on validation set
    acc1 = validate(val_loader, model, criterion, args)  
    
    train_loss.append(epoch_loss)
    val_acc.append(acc1)
    print('************train_loss {} val_acc {}*************'.format(epoch_loss, acc1))
    
    # remember best acc@1 and save checkpoint
    is_best = acc1 > best_acc1
    best_acc1 = max(acc1, best_acc1)

#     if not args.multiprocessing_distributed or (args.multiprocessing_distributed
#             and args.rank % ngpus_per_node == 0):
    save_checkpoint({
        'epoch': epoch + 1,
        'arch': args.arch,
        'state_dict': model.state_dict(),
        'best_acc1': best_acc1,
        'optimizer' : optimizer.state_dict(),
    }, is_best)
    
    print('epoch time', time.time() - start_time)

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Epoch: [0][   0/5005]	Time  2.909 ( 2.909)	Data  2.300 ( 2.300)	Loss 3.2712e+00 (3.2712e+00)	Acc@1  39.45 ( 39.45)	Acc@5  58.20 ( 58.20)
Epoch: [0][1000/5005]	Time  0.330 ( 0.340)	Data  0.000 ( 0.003)	Loss 2.0268e+00 (2.0182e+00)	Acc@1  53.12 ( 54.52)	Acc@5  74.61 ( 77.32)
Epoch: [0][2000/5005]	Time  0.357 ( 0.341)	Data  0.000 ( 0.001)	Loss 2.1739e+00 (1.9829e+00)	Acc@1  52.73 ( 55.13)	Acc@5  75.78 ( 77.87)
Epoch: [0][3000/5005]	Time  0.342 ( 0.342)	Data  0.000 ( 0.001)	Loss 2.1030e+00 (1.9606e+00)	Acc@1  55.86 ( 55.51)	Acc@5  77.34 ( 78.22)
Epoch: [0][4000/5005]	Time  0.356 ( 0.343)	Data  0.000 ( 0.001)	Loss 1.8116e+00 (1.9474e+00)	Acc@1  59.38 ( 55.77)	Acc@5  81.25 ( 78.44)
Epoch: [0][5000/5005]	Time  0.326 ( 0.340)	Data  0.000 ( 0.001)	Loss 1.9038e+00 (1.9351e+00)	Acc@1  53.91 ( 56.01)	Acc@5  78.52 ( 78.64)
Test: [   0/1565]	Time  0.347 ( 0.347)	Loss 8.6259e-01 (8.6259e-01)	Acc@1  78.12 ( 78.12)	Acc@5  93.75 ( 93.75)
Test: [1000/1565]	Time  0.015 ( 0.050)	Loss 2.2628e+00 (1.4806e+00