In [4]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import os
import torch
import numpy as np
import time
import tqdm
import shutil
import torch.nn as nn
import torchvision.models as  models
from torch.utils.data import DataLoader, ConcatDataset
import torchvision.datasets.voc as voc
import torch.optim as optim
from PIL import Image
from glob import glob
from tqdm import tqdm
from torchvision import transforms
from torchvision.models import resnet18
import torch.utils.model_zoo as model_zoo
import xml.etree.cElementTree as ET

In [3]:
os.chdir('/content/drive/MyDrive')

In [None]:
!unzip pascal-aug-20230514T005135Z-001.zip

Archive:  pascal-aug-20230514T005135Z-001.zip
replace pascal-aug/pascal-0-4/VOCdevkit/VOC2012/ImageSets/Main/train.txt? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

In [5]:
data_dir = '/content/drive/MyDrive/pascal'
aug_dir = '/content/drive/MyDrive/CS444Project/Project/data/pascal-aug/pascal-0-8'
ckpt_dir = '/content/drive/MyDrive/CS444Project/Project/checkpoints'
object_categories = ['aeroplane', 'bicycle', 'bird', 'boat',
                     'bottle', 'bus', 'car', 'cat', 'chair',
                     'cow', 'diningtable', 'dog', 'horse',
                     'motorbike', 'person', 'pottedplant',
                     'sheep', 'sofa', 'train', 'tvmonitor']
num_classes = len(object_categories)
batch_size = 32
resnet_lr = 1e-5
fc_lr = 5e-3
num_epochs = 25

mean = [0.457342265910642, 0.4387686270106377, 0.4073427106250871]
std = [0.26753769276329037, 0.2638145880487105, 0.2776826934044154]

device = 'cuda' if torch.cuda.is_available() else 'cpu'
np.random.seed(1902)
torch.manual_seed(1902)

<torch._C.Generator at 0x7ff8bc5ca070>

## Data Pipeline

Download the PASCAL VOC dataset and create train and val data loaders.

In [6]:
class PascalVOC_Dataset(voc.VOCDetection):
    """Pascal VOC Detection Dataset"""
    def __init__(self, root, image_set='train', download=False, transform=None, target_transform=None):
        super().__init__(root, image_set=image_set, download=download, transform=transform, target_transform=target_transform)
    
    def __getitem__(self, index):
        return super().__getitem__(index)
    
    def __len__(self):
        return len(self.images)

In [7]:
def encode_labels(target):
    """Encode multiple labels using 1/0 encoding"""
    ls = target['annotation']['object']
    j = []
    if type(ls) == dict:
        if int(ls['difficult']) == 0:
            j.append(object_categories.index(ls['name']))
    else:
        for i in range(len(ls)):
            if int(ls[i]['difficult']) == 0:
                j.append(object_categories.index(ls[i]['name']))
    k = np.zeros(len(object_categories))
    k[j] = 1
    return torch.from_numpy(k)

In [8]:
transformations = transforms.Compose([transforms.Resize((300, 300)),
                                      transforms.ToTensor(),
                                      transforms.Normalize(mean=mean, std=std)])
transformations_valid = transforms.Compose([transforms.Resize(330), 
                                            transforms.CenterCrop(300), 
                                            transforms.ToTensor(),
                                            transforms.Normalize(mean=mean, std=std)])

In [9]:
dataset_train = PascalVOC_Dataset(data_dir,
                                  image_set='train', 
                                  download=False, 
                                  transform=transformations, 
                                  target_transform=encode_labels)
dataset_aug = PascalVOC_Dataset(aug_dir,
                                  image_set='train', 
                                  download=False, 
                                  transform=transformations, 
                                  target_transform=encode_labels)

dataset_combined = ConcatDataset([dataset_train, dataset_aug])

train_loader = DataLoader(dataset_combined, batch_size=batch_size, num_workers=2, shuffle=True)

dataset_valid = PascalVOC_Dataset(data_dir, 
                                  image_set='val', 
                                  download=False, 
                                  transform=transformations_valid, 
                                  target_transform=encode_labels)
valid_loader = DataLoader(dataset_valid, batch_size=batch_size, num_workers=2)

In [None]:
# for i in glob('/content/drive/MyDrive/cs444/project/da-fusion/aug/textual-inversion-0.5/pascal-0-8/*.png'):
#   im = Image.open(i)
#   filepath = '/content/drive/MyDrive/cs444/project/da-fusion/pascal-aug/pascal-0-8/VOCdevkit/VOC2012/JPEGImages/' + i.split('/')[-1].split('.')[0] + '.jpg'
#   im.save(filepath)


In [None]:
# for j in range(20): 
#   for i in sorted(glob('/content/drive/MyDrive/cs444/project/da-fusion/pascal-aug/pascal-0-8/VOCdevkit/VOC2012/JPEGImages/*.jpg')):
#     num = i.split('/')[-1].split('-')[1]
#     name = i.split('/')[-1]
#     cat = object_categories[j]
 
#     if int(num) >= 8*j and int(num) < 8*j + 8:
#       root = ET.Element("annotation")
#       ET.SubElement(root, "folder", name="folder").text = "VOC2012"
#       ET.SubElement(root, "filename", name="filename").text = name

#       source = ET.SubElement(root, "source")

#       ET.SubElement(source, "database", name="database").text = "The VOC2007 Database"
#       ET.SubElement(source, "annotation", name="annotation").text = "PASCAL VOC2007"
#       ET.SubElement(source, "image", name="image").text = "flickr"

#       size = ET.SubElement(root, "size")
#       ET.SubElement(size, "width", name="width").text = '486'
#       ET.SubElement(size, "height", name="height").text = '500'
#       ET.SubElement(size, "depth", name="depth").text = '3'

#       ET.SubElement(root, "segmented", name="segmented").text = '0'

#       obj = ET.SubElement(root, "object")
#       ET.SubElement(obj, "name", name="name").text = cat
#       ET.SubElement(obj, "pose", name="pose").text = 'Unspecified'
#       ET.SubElement(obj, "truncated", name="name").text = '0'
#       ET.SubElement(obj, "difficult", name="difficult").text = '0'

#       tree = ET.ElementTree(root)
#       filename = '/content/drive/MyDrive/cs444/project/da-fusion/pascal-aug/pascal-0-8/VOCdevkit/VOC2012/Annotations/'+ name.split('.')[0] + '.xml'
#       print(filename)
#       tree.write(filename)
      

## Define Model

In [10]:
net = resnet18(pretrained=True)
net.avgpool = torch.nn.AdaptiveAvgPool2d(1)
num_ftrs = net.fc.in_features
net.fc = torch.nn.Linear(num_ftrs, num_classes)
net = net.to(device)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 87.3MB/s]


## Define Training Parameters

In [11]:
optimizer = optim.SGD([{'params': list(net.parameters())[:-1], 'lr': resnet_lr, 'momentum': 0.9},
                       {'params': list(net.parameters())[-1], 'lr': fc_lr, 'momentum': 0.9}])
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 12, eta_min=0, last_epoch=-1)
criterion = torch.nn.BCEWithLogitsLoss(reduction='sum')

In [12]:
def run_test(net, test_loader, criterion):
    correct = 0
    total = 0
    avg_test_loss = 0.0
    l = len(test_loader)
    with torch.no_grad():
        for _, (images, labels) in enumerate(test_loader):
            images, labels = images.to(device), labels.to(device)
            
            outputs = net(images)
            predictions = torch.argmax(outputs, dim=1)
            labels = torch.argmax(labels, dim=1)
            correct += torch.sum(predictions == labels)
            total += labels.size(0)

    print(f'Accuracy of the network on the test images: {100 * correct / total:.2f} %')

In [13]:
def train(net, criterion, optimizer, num_epochs, print_freq = 100):
    for epoch in range(num_epochs):
        running_loss = 0.0
        running_correct = 0.0
        running_total = 0.0
        start_time = time.time()

        net.train()

        for i, (images, labels) in enumerate(train_loader, 0):
            images = images.to(device)
            labels = labels.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()

            # Get predicted results
            predicted = torch.argmax(outputs, dim=1)
            labels = torch.argmax(labels, dim=1)

            # print statistics
            running_loss += loss.item()

            # calculate accuracy
            running_total += labels.size(0)
            running_correct += (predicted == labels).sum().item()

            # print every 2000 mini-batches
            if i % print_freq == (print_freq - 1):
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / print_freq:.3f} acc: {100*running_correct / running_total:.2f} time: {time.time() - start_time:.2f}')
                running_loss, running_correct, running_total = 0.0, 0.0, 0.0
                start_time = time.time()

        # Run the run_test() function after each epoch
        net.eval()
        run_test(net, valid_loader, criterion)

In [14]:
import warnings
warnings.filterwarnings('ignore')

In [15]:
train(net, criterion, optimizer, num_epochs=num_epochs)

save_dir = os.path.join(ckpt_dir, 'da-fusion_8.pt')
torch.save(net.state_dict(), save_dir)

[1,   100] loss: 160.830 acc: 13.12 time: 774.38
[1,   200] loss: 123.578 acc: 27.88 time: 752.78
Accuracy of the network on the test images: 46.51 %
[2,   100] loss: 102.454 acc: 46.25 time: 38.14
[2,   200] loss: 92.587 acc: 54.62 time: 39.14
Accuracy of the network on the test images: 64.35 %
[3,   100] loss: 79.701 acc: 62.31 time: 38.17
[3,   200] loss: 76.570 acc: 64.62 time: 39.27
Accuracy of the network on the test images: 67.78 %
[4,   100] loss: 70.331 acc: 67.66 time: 37.71
[4,   200] loss: 67.222 acc: 67.34 time: 38.37
Accuracy of the network on the test images: 68.07 %
[5,   100] loss: 63.378 acc: 70.41 time: 37.59
[5,   200] loss: 60.814 acc: 69.97 time: 39.33
Accuracy of the network on the test images: 68.61 %
[6,   100] loss: 57.874 acc: 72.50 time: 39.05
[6,   200] loss: 56.290 acc: 72.09 time: 38.48
Accuracy of the network on the test images: 68.92 %
[7,   100] loss: 54.609 acc: 72.94 time: 36.82
[7,   200] loss: 53.688 acc: 73.78 time: 37.73
Accuracy of the network o

In [16]:
batch_size = 32
resnet_lr = 1e-6
fc_lr = 5e-4
num_epochs = 25

In [17]:
train(net, criterion, optimizer, num_epochs=num_epochs)

save_dir = os.path.join(ckpt_dir, 'da-fusion_8b.pt')
torch.save(net.state_dict(), save_dir)

[1,   100] loss: 20.884 acc: 82.28 time: 37.21
[1,   200] loss: 20.584 acc: 81.78 time: 37.02
Accuracy of the network on the test images: 73.60 %
[2,   100] loss: 19.172 acc: 82.28 time: 37.35
[2,   200] loss: 20.244 acc: 81.03 time: 38.13
Accuracy of the network on the test images: 73.06 %
[3,   100] loss: 18.173 acc: 82.19 time: 38.71
[3,   200] loss: 18.216 acc: 82.84 time: 38.16
Accuracy of the network on the test images: 73.64 %
[4,   100] loss: 16.958 acc: 82.56 time: 38.19
[4,   200] loss: 18.301 acc: 81.25 time: 37.52
Accuracy of the network on the test images: 73.06 %
[5,   100] loss: 16.168 acc: 82.66 time: 38.41
[5,   200] loss: 16.804 acc: 81.84 time: 36.27
Accuracy of the network on the test images: 73.31 %
[6,   100] loss: 15.094 acc: 82.97 time: 36.72
[6,   200] loss: 15.846 acc: 82.78 time: 37.57
Accuracy of the network on the test images: 73.43 %
[7,   100] loss: 14.704 acc: 82.09 time: 37.11
[7,   200] loss: 14.707 acc: 83.28 time: 37.80
Accuracy of the network on the

KeyboardInterrupt: ignored

In [None]:
train(net, criterion, optimizer, num_epochs=num_epochs)

save_dir = os.path.join(ckpt_dir, 'baseline.pt')
torch.save(net.state_dict(), save_dir)

[1,   100] loss: 162.784 acc: 13.94 time: 2510.78
Accuracy of the network on the test images: 32.70 %
[2,   100] loss: 110.945 acc: 40.19 time: 38.42
Accuracy of the network on the test images: 55.32 %
[3,   100] loss: 90.052 acc: 53.69 time: 38.15
Accuracy of the network on the test images: 62.10 %
[4,   100] loss: 79.339 acc: 59.03 time: 39.18
Accuracy of the network on the test images: 64.28 %
[5,   100] loss: 71.841 acc: 63.97 time: 37.10
Accuracy of the network on the test images: 67.51 %
[6,   100] loss: 66.768 acc: 65.56 time: 37.70
Accuracy of the network on the test images: 68.13 %
[7,   100] loss: 62.867 acc: 67.38 time: 38.22
Accuracy of the network on the test images: 68.21 %
[8,   100] loss: 58.321 acc: 68.59 time: 37.98
Accuracy of the network on the test images: 68.73 %
[9,   100] loss: 55.473 acc: 69.81 time: 38.38
Accuracy of the network on the test images: 69.67 %
[10,   100] loss: 54.706 acc: 69.19 time: 37.57
Accuracy of the network on the test images: 69.74 %
[11, 

In [None]:
base_model = torch.load('/content/drive/MyDrive/cs444/project/da-fusion.pt', map_location = 'cpu')

In [None]:
base_model = resnet18()
base_model.avgpool = torch.nn.AdaptiveAvgPool2d(1)
num_ftrs = base_model.fc.in_features
base_model.fc = torch.nn.Linear(num_ftrs, 20)
base_model = base_model.to(device)

In [None]:
state_dict = torch.load('/content/drive/MyDrive/cs444/project/da-fusion.pt', map_location = 'cpu')

from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
  name = k[11:]
  new_state_dict[name] = v

base_model.load_state_dict(new_state_dict)

<All keys matched successfully>

In [None]:
torch.argmax(base_model(x[0][0].unsqueeze(0)), dim=1)

tensor([11])

In [None]:
run_test(base_model, iter(valid_loader), criterion)

Running: 0.005494505494505495
Running: 0.01098901098901099
Running: 0.016483516483516484
Running: 0.02197802197802198
Running: 0.027472527472527472
Running: 0.03296703296703297
Running: 0.038461538461538464
Running: 0.04395604395604396
Running: 0.04945054945054945
Running: 0.054945054945054944
Running: 0.06043956043956044
Running: 0.06593406593406594
Running: 0.07142857142857142
Running: 0.07692307692307693
Running: 0.08241758241758242
Running: 0.08791208791208792
Running: 0.09340659340659341
Running: 0.0989010989010989
Running: 0.1043956043956044
Running: 0.10989010989010989
Running: 0.11538461538461539
Running: 0.12087912087912088
Running: 0.12637362637362637
Running: 0.13186813186813187
Running: 0.13736263736263737
Running: 0.14285714285714285
Running: 0.14835164835164835
Running: 0.15384615384615385
Running: 0.15934065934065933
Running: 0.16483516483516483
Running: 0.17032967032967034
Running: 0.17582417582417584
Running: 0.1813186813186813
Running: 0.18681318681318682
Running: 0.1