In [None]:
# from google.colab import drive
#
# drive.mount('/content/drive')

### Preparing

In [None]:
# Change working directory to the root of the project
# %cd ...
# !pwd

# Unzip data
# !unzip 2021VRDL_HW1_datasets.zip
# !unzip testing_images.zip -d testing
# !unzip training_images.zip -d training

In [None]:
from torchvision.utils import make_grid, save_image
import os
import random
import time
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

# from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset

In [None]:
LR = 0.001
# Seed for every RNG the notebook draws from: random.uniform (batch_augment),
# np.random.choice (CAL attention sampling), and torch (weight init,
# DataLoader shuffling, random transforms). cuDNN kernels may still be
# non-deterministic, so this makes runs more repeatable, not bit-exact.
SEED = 42

# Root of the homework data on the mounted Google Drive. Every path below is
# derived from it, so the notebook can be pointed elsewhere in one place.
DATA_ROOT = "/content/drive/MyDrive/2021VRDL/HW1"
TRAIN_IMG_DIR = os.path.join(DATA_ROOT, "training")
TEST_IMG_DIR = os.path.join(DATA_ROOT, "testing")
TRAIN_LABEL_FILE = os.path.join(DATA_ROOT, "training_labels.txt")
TEST_FILE = os.path.join(DATA_ROOT, "testing_img_order.txt")
CLASSES_FILE = os.path.join(DATA_ROOT, "classes.txt")
MODEL_SAVE_DIR = os.path.join(DATA_ROOT, "models")
PREDICTION_DIR = os.path.join(DATA_ROOT, "prediction")
TRAIN_NAME_FILE = os.path.join(DATA_ROOT, "train_filenames.txt")
VALID_NAME_FILE = os.path.join(DATA_ROOT, "valid_filenames.txt")
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
N_CLASSES = 200
print("Current device:", DEVICE)

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Class list: line i of CLASSES_FILE is the class name whose integer label is i.
CLS_LABEL_TO_NAME = dict(
    enumerate(pd.read_csv(CLASSES_FILE, names=["class"])["class"].to_list())
)
# Inverse mapping: class name -> integer label.
CLS_NAME_TO_LABEL = {name: label for label, name in CLS_LABEL_TO_NAME.items()}


# Load training data's labels
def read_label_as_dict(name_to_label_file):
    """Read a space-separated "<filename> <class name>" file into a dict.

    Args:
        name_to_label_file: path to a header-less file with one
            "filename class" pair per line.

    Returns:
        dict mapping filename -> class name. If a filename is repeated, the
        last occurrence wins.
    """
    # Fix: the argument used to be ignored and TRAIN_LABEL_FILE was always read.
    df = pd.read_csv(name_to_label_file, sep=" ",
                     names=["filename", "class"])
    return df.set_index("filename")["class"].to_dict()


TRAIN_NAME_TO_CLS_NAME = read_label_as_dict(TRAIN_LABEL_FILE)
# filename -> integer class label
TRAIN_NAME_TO_CLS_LABEL = {
    fname: CLS_NAME_TO_LABEL[cls_name]
    for fname, cls_name in TRAIN_NAME_TO_CLS_NAME.items()
}
TRAIN_NAMES = list(TRAIN_NAME_TO_CLS_LABEL)
TEST_NAMES = pd.read_csv(TEST_FILE, names=["filename"])["filename"].to_list()

# Predefined split of the training images:
# 2600 for training (13 per class), 400 for validation (2 per class).
train_split_df = pd.read_csv(TRAIN_NAME_FILE, names=["filename"])
TRAIN_X = train_split_df.to_numpy().ravel().tolist()
valid_split_df = pd.read_csv(VALID_NAME_FILE, names=["filename"])
VALID_X = valid_split_df.to_numpy().ravel().tolist()

In [None]:
# Use more data to train: move the even-indexed half of the validation split
# (VALID_X[0::2]) into training and keep the odd-indexed half (VALID_X[1::2])
# for validation.
# Fix: the old cell did `tmp_t_x = TRAIN_X` and appended to it. That mutated
# TRAIN_X in place, while VALID_X still held every name, so the moved images
# were both trained on and validated on (inflated validation accuracy).
# New lists are built and both splits are rebound explicitly.
# NOTE: run once after loading the split files; re-running moves another
# half of VALID_X into training.
tmp_t_x = TRAIN_X + VALID_X[0::2]
tmp_v_x = VALID_X[1::2]
TRAIN_X, VALID_X = tmp_t_x, tmp_v_x
print(len(tmp_t_x), len(tmp_v_x))

2800 200


In [None]:
class BirdDataset(Dataset):
    """Dataset of image files stored in a single directory.

    Args:
        img_dir: directory that contains the images.
        img_names: file names relative to ``img_dir``; they define the
            dataset's order and length.
        img_name_to_label: dict mapping file name -> integer label; not
            consulted when ``test`` is True.
        transform: callable applied to the RGB PIL image.
        test: when True, ``__getitem__`` returns only the transformed image;
            otherwise it returns ``(image, label)``.
    """

    def __init__(self, img_dir, img_names, img_name_to_label,
                 transform, test=False):
        super().__init__()
        self.img_dir = img_dir
        self.img_names = img_names
        self.img_name_to_label = img_name_to_label
        self.transform = transform
        self.test = test

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        file_name = self.img_names[idx]
        image = Image.open(os.path.join(self.img_dir, file_name)).convert("RGB")
        if self.test is True:
            return self.transform(image)
        # Look the label up before transforming, as the original did.
        label = self.img_name_to_label[file_name]
        return self.transform(image), label

### Load data

In [None]:
# Images per mini-batch, shared by the train / valid / test loaders.
BATCH_SIZE: int = 8

In [None]:
# Per-channel statistics of the training images, used to normalise inputs.
# Mean: (0.4819434270441397, 0.49756442238151166, 0.4321436327680999)
# Std: (0.1327144718474103, 0.12794266425102524, 0.1728887645128543)
NORM_MEAN = (0.48194, 0.49756, 0.43214)
NORM_STD = (0.13271, 0.12794, 0.17289)

# Images are resized to RESIZE_SIZE, then cropped to CROP_SIZE
# (random crop for training, center crop for validation / test).
RESIZE_SIZE = (324, 324)
CROP_SIZE = (300, 300)

# Earlier experiments, no longer used: 45x45 random crops; 224x224 crops from
# 256x256 with rotation(30) + horizontal flip; 360x360 crops from 400x400
# with rotation(45).

# Training pipeline: resize, light rotation / flip augmentation, random crop.
TRAIN_TRANSFORM = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(RESIZE_SIZE),
        transforms.RandomRotation(20),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomCrop(CROP_SIZE),
        transforms.Normalize(NORM_MEAN, NORM_STD),
    ]
)


def _eval_transform():
    """Build the deterministic resize + center-crop pipeline used for eval."""
    return transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize(RESIZE_SIZE),
            transforms.CenterCrop(CROP_SIZE),
            transforms.Normalize(NORM_MEAN, NORM_STD),
        ]
    )


# Validation and test share the same (separately constructed) pipeline.
VALID_TRANSFORM = _eval_transform()
TEST_TRANSFORM = _eval_transform()


# Datasets built from the predefined splits (TRAIN_X / VALID_X). A random
# split via sklearn's train_test_split(TRAIN_NAMES, test_size=0.2) was used
# earlier instead.
_labelled_kwargs = dict(
    img_dir=TRAIN_IMG_DIR,
    img_name_to_label=TRAIN_NAME_TO_CLS_LABEL,
)
TRAIN_DATASET = BirdDataset(
    img_names=TRAIN_X, transform=TRAIN_TRANSFORM, **_labelled_kwargs
)
VALID_DATASET = BirdDataset(
    img_names=VALID_X, transform=VALID_TRANSFORM, **_labelled_kwargs
)
# The test set is unlabelled: test=True makes items plain image tensors.
TEST_DATASET = BirdDataset(
    img_dir=TEST_IMG_DIR,
    img_names=TEST_NAMES,
    img_name_to_label=TRAIN_NAME_TO_CLS_LABEL,
    transform=TEST_TRANSFORM,
    test=True,
)

# Only training is shuffled; valid / test keep file order so predictions
# line up with VALID_X / TEST_NAMES.
TRAIN_DATALOADER = DataLoader(
    TRAIN_DATASET, batch_size=BATCH_SIZE, shuffle=True
)
VALID_DATALOADER = DataLoader(
    VALID_DATASET, batch_size=BATCH_SIZE, shuffle=False
)
TEST_DATALOADER = DataLoader(
    TEST_DATASET, batch_size=BATCH_SIZE, shuffle=False
)

In [None]:
# # Install EfficientNet
# !pip install efficientnet_pytorch
# from efficientnet_pytorch import EfficientNet
# MODEL = EfficientNet.from_pretrained('efficientnet-b5')

### Model classes (CAL with Bilinear Attention Pooling)

In [None]:
#
class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


# Bilinear Attention Pooling
class BAP(nn.Module):
    """Bilinear Attention Pooling with a counterfactual branch.

    Each of the M attention maps pools every one of the C feature channels
    (global average over H x W), giving an (B, M * C) feature matrix. The
    same pooling is applied with "fake" attention: uniform(0, 2) noise in
    training mode, all-ones in eval mode. Both results get a signed square
    root and an L2 normalisation over the last dimension.

    Args:
        pool: only "GAP" is accepted (checked by assertion).
    """

    def __init__(self, pool="GAP"):
        super(BAP, self).__init__()
        assert pool in ["GAP"]
        self.pool = None
        self.epsilon = 1e-6

    @staticmethod
    def _pool(weights, features):
        """Attention-weighted spatial mean: (B,M,H,W) x (B,C,H,W) -> (B, M*C)."""
        batch, _, height, width = features.size()
        pooled = torch.einsum("imjk,injk->imn", weights, features)
        return (pooled / float(height * width)).view(batch, -1)

    def _signed_sqrt_l2(self, matrix):
        """Signed square root followed by L2 normalisation over the last dim."""
        rooted = torch.sign(matrix) * torch.sqrt(
            torch.abs(matrix) + self.epsilon
        )
        return F.normalize(rooted, dim=-1)

    def forward(self, features, attentions):
        """Return ``(feature_matrix, counterfactual_feature)``, both (B, M*C)."""
        _, _, height, width = features.size()
        _, _, att_h, att_w = attentions.size()

        # Resize attentions to the feature grid when the sizes differ
        # (F.interpolate with align_corners=True == F.upsample_bilinear).
        if (att_h, att_w) != (height, width):
            attentions = F.interpolate(
                attentions, size=(height, width),
                mode="bilinear", align_corners=True,
            )

        feature_matrix = self._signed_sqrt_l2(self._pool(attentions, features))

        fake_att = (
            torch.zeros_like(attentions).uniform_(0, 2)
            if self.training
            else torch.ones_like(attentions)
        )
        counterfactual_feature = self._signed_sqrt_l2(
            self._pool(fake_att, features)
        )
        return feature_matrix, counterfactual_feature


class BasicConv2d(nn.Module):
    """Bias-free Conv2d -> BatchNorm2d(eps=1e-3) -> in-place ReLU.

    Extra keyword arguments (kernel_size, stride, ...) go to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        # Creation order (conv, then bn) and attribute names are kept so
        # parameter init and state_dict keys stay the same.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        normalized = self.bn(self.conv(x))
        return F.relu(normalized, inplace=True)


class CAL(nn.Module):
    """Counterfactual Attention Learning head on top of a CNN backbone.

    The backbone's feature maps feed a 1x1 ``BasicConv2d`` that produces M
    attention maps. ``BAP`` pools the features under those maps (and under
    counterfactual maps), and a bias-free linear layer classifies the
    pooled features.

    Args:
        backbone: feature extractor. With ``use_pytorch_resnet=True`` it must
            return a flattened (B, C*H*W) tensor with square H == W (e.g. a
            torchvision ResNet whose avgpool/fc are Identity). Otherwise it
            returns (B, C) vectors, which are treated as 1x1 maps.
        backbone_out_feat: C, the number of backbone output channels.
        num_classes: number of output classes.
        M: number of attention maps.
        use_pytorch_resnet: selects how the backbone output is reshaped.
    """

    def __init__(
        self, backbone, backbone_out_feat,
            num_classes, M=32, use_pytorch_resnet=True
    ):
        super(CAL, self).__init__()
        self.use_pytorch_resnet = use_pytorch_resnet
        self.num_classes = num_classes
        self.M = M  # channels of attention maps
        self.epsilon = 1e-6

        # Network Initialization
        self.features = backbone
        self.num_features = backbone_out_feat

        # Attention Maps
        self.attentions = BasicConv2d(self.num_features,
                                      self.M, kernel_size=1)

        # Bilinear Attention Pooling
        self.bap = BAP(pool="GAP")

        # Classification Layer
        self.fc = nn.Linear(self.M * self.num_features,
                            self.num_classes, bias=False)

        print(
            "Using {} as feature extractor, num_classes: {}, "
            "num_attentions: {}".format(
                type(self.features).__name__, self.num_classes, self.M
            )
        )

    def _feature_maps(self, x):
        """Run the backbone and reshape its output to (B, C, H, W) maps."""
        feature_maps = self.features(x)
        batch = feature_maps.size(0)
        if self.use_pytorch_resnet is True:
            # Flattened (B, C*H*W) ResNet output with square spatial size.
            h_w = int(np.sqrt(feature_maps.size(1) // self.num_features))
            return feature_maps.view(batch, self.num_features, h_w, h_w)
        # Vector features: view them as 1x1 maps.
        return feature_maps.view(batch, feature_maps.size(1), 1, 1)

    def visualize(self, x):
        """Return ``(logits, attention_maps)`` for inspection / plotting."""
        feature_maps = self._feature_maps(x)
        attention_maps = self.attentions(feature_maps)

        # Fix: BAP returns (feature_matrix, counterfactual_feature). The old
        # code multiplied that tuple by 100.0 and raised a TypeError.
        feature_matrix, _ = self.bap(feature_maps, attention_maps)
        p = self.fc(feature_matrix * 100.0)

        return p, attention_maps

    def forward(self, x):
        """Classify ``x`` and produce attention maps for augmentation.

        Returns:
            p: logits from the attention-pooled features.
            p - fc(counterfactual features): the auxiliary "effect" logits.
            feature_matrix: normalised BAP features (used by the centre loss).
            attention_map: in training mode, two attention channels sampled
                per image (B, 2, H, W), one for cropping and one for
                dropping; in eval mode, the mean of all M maps (B, 1, H, W).
        """
        batch_size = x.size(0)

        # Feature Maps, Attention Maps and Feature Matrix
        feature_maps = self._feature_maps(x)
        attention_maps = self.attentions(feature_maps)

        feature_matrix, feature_matrix_hat = self.bap(
            feature_maps, attention_maps)

        # Classification
        p = self.fc(feature_matrix * 100.0)

        # Generate Attention Map
        if self.training:
            # Sample two attention channels per image with probability
            # proportional to sqrt of each channel's (shifted) total mass.
            attention_map = []
            for i in range(batch_size):
                tmp_map = attention_maps[i] - min(attention_maps[i].min(), 0)
                attention_weights = torch.sqrt(
                    tmp_map.sum(dim=(1, 2)).detach() + self.epsilon
                )
                attention_weights = F.normalize(attention_weights, p=1, dim=0)
                k_index = np.random.choice(self.M, 2,
                                           p=attention_weights.cpu().numpy())
                attention_map.append(attention_maps[i, k_index, ...])
            attention_map = torch.stack(
                attention_map
            )  # (B, 2, H, W) - one for cropping, the other for dropping
        else:
            attention_map = torch.mean(
                attention_maps, dim=1, keepdim=True
            )  # (B, 1, H, W)

        return (p, p - self.fc(feature_matrix_hat * 100.0),
                feature_matrix, attention_map)

### Training functions (CAL + WS-DAN attention augmentation)

In [None]:
def batch_augment(images, attention_map,
                  mode="crop", theta=0.5, padding_ratio=0.1):
    """Attention-guided augmentation (WS-DAN style attention crop / drop).

    Args:
        images: batch indexed as (B, C, H, W).
        attention_map: per-sample attention (B, K, h, w). It is bilinearly
            resized to (H, W). The crop box uses channel 0 only; in drop mode
            the mask must broadcast against ``images`` (e.g. K == 1).
        mode: "crop" zooms into the box around pixels whose attention is
            >= threshold; "drop" zeroes pixels whose attention is >= threshold.
        theta: threshold ratio relative to each sample's max attention, or
            a (low, high) tuple to draw the ratio uniformly per sample.
        padding_ratio: fraction of H / W added on each side of the crop box.

    Returns:
        Tensor with the same (B, C, H, W) size as ``images``.

    Raises:
        ValueError: if ``mode`` is neither "crop" nor "drop".
    """
    batches, _, img_h, img_w = images.size()

    def _resize(tensor):
        # Equivalent to the deprecated F.upsample_bilinear(tensor, size=...).
        return F.interpolate(
            tensor, size=(img_h, img_w), mode="bilinear", align_corners=True
        )

    def _threshold(atten_map):
        if isinstance(theta, tuple):
            return random.uniform(*theta) * atten_map.max()
        return theta * atten_map.max()

    if mode == "crop":
        crop_images = []
        for batch_index in range(batches):
            atten_map = attention_map[batch_index:batch_index + 1]
            theta_c = _threshold(atten_map)
            crop_mask = _resize(atten_map) >= theta_c
            nonzero_indices = torch.nonzero(crop_mask[0, 0, ...])

            if nonzero_indices.numel() == 0:
                # Nothing passes the threshold (e.g. an all-negative map):
                # keep the whole image. Unlike the old fallback, this does
                # not assert H == W, so non-square images work.
                height_min, height_max = 0, img_h
                width_min, width_max = 0, img_w
            else:
                pad_h = padding_ratio * img_h
                pad_w = padding_ratio * img_w
                # Slice ends are exclusive, hence the +1: the last masked
                # row/column is kept, and padding_ratio=0 with a one-pixel
                # mask no longer produces an empty crop.
                height_min = max(
                    int(nonzero_indices[:, 0].min().item() - pad_h), 0)
                height_max = min(
                    int(nonzero_indices[:, 0].max().item() + 1 + pad_h), img_h)
                width_min = max(
                    int(nonzero_indices[:, 1].min().item() - pad_w), 0)
                width_max = min(
                    int(nonzero_indices[:, 1].max().item() + 1 + pad_w), img_w)

            crop_images.append(
                _resize(
                    images[
                        batch_index:batch_index + 1,
                        :,
                        height_min:height_max,
                        width_min:width_max,
                    ]
                )
            )
        return torch.cat(crop_images, dim=0)

    if mode == "drop":
        drop_masks = []
        for batch_index in range(batches):
            atten_map = attention_map[batch_index:batch_index + 1]
            theta_d = _threshold(atten_map)
            # Keep only pixels whose attention is below the threshold.
            drop_masks.append(_resize(atten_map) < theta_d)
        drop_masks = torch.cat(drop_masks, dim=0)
        return images * drop_masks.float()

    raise ValueError(
        "Expected mode in ['crop', 'drop'], "
        "but received unsupported augmentation method %s"
        % mode
    )

In [None]:
def eval_model(model, test_loader, device):
    """Top-1 accuracy of a CAL-style model on a labelled loader.

    The model is switched to eval mode for the pass and restored to its
    previous mode afterwards, even if an exception is raised.

    Args:
        model: module whose forward returns a tuple whose first element is
            the class logits (as ``CAL.forward`` does).
        test_loader: iterable of ``(images, labels)`` batches.
        device: device the batches are moved to.

    Returns:
        Accuracy in [0, 1]; 0.0 when the loader yields no samples (the old
        version raised ZeroDivisionError).
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in test_loader:
                images = images.to(device)
                labels = labels.to(device)
                # Only the raw logits are scored; aux outputs are ignored.
                y_pred_raw = model(images)[0]
                predicted = y_pred_raw.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)
    test_acc = correct / total if total else 0.0
    print(f"Accuracy of the network on the test images: {test_acc:.4f}")
    return test_acc

In [None]:
class CenterLoss(nn.Module):
    """Sum of squared differences between features and their class centres,
    divided by the batch size (first dimension of ``outputs``)."""

    def __init__(self):
        super(CenterLoss, self).__init__()
        self.l2_loss = nn.MSELoss(reduction="sum")

    def forward(self, outputs, targets):
        batch_size = outputs.size(0)
        squared_error_sum = self.l2_loss(outputs, targets)
        return squared_error_sum / batch_size


def train_model(
    model,
    batch_size,
    criterion,
    optimizer,
    scheduler,
    n_epochs,
    device,
    train_loader,
    test_loader,
    model_save_dir,
    config_beta,
    previous_record: dict = None,
    no_valid=False,
):
    """Train a CAL model with WS-DAN attention crop / drop augmentation.

    Each step combines raw cross-entropy (weight 1/3), counterfactual ("aux")
    cross-entropy over raw + augmented images, augmented-image cross-entropy
    (weight 2/3), and a centre loss. The centre loss pulls BAP features
    toward per-class running centres, which are updated with rate
    ``config_beta``. After every epoch the model is optionally validated,
    ``scheduler.step`` receives the monitored accuracy, and a checkpoint dict
    (model, losses, accuracies, test_accuracies, feature_center) is saved in
    ``model_save_dir`` with that accuracy in the file name.

    Args:
        model: CAL instance (needs ``num_classes``, ``M``, ``num_features``).
        batch_size: only printed.
        criterion: only printed; the losses are built internally.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: stepped with an accuracy value, e.g.
            ``ReduceLROnPlateau(mode="max")``.
        n_epochs: number of epochs to run.
        device: device the batches are moved to.
        train_loader: yields ``(images, labels)``; needs ``.dataset``.
        test_loader: validation loader, used unless ``no_valid`` is True.
        model_save_dir: directory for per-epoch checkpoints.
        config_beta: update rate of the feature centres.
        previous_record: checkpoint dict written by this function, to resume
            from. Its history lists are copied, not mutated.
        no_valid: skip validation. Training accuracy is then used for the
            scheduler and the checkpoint name (the old code raised NameError
            because ``test_acc`` was never assigned).

    Returns:
        ``(model, losses, accuracies, test_accuracies)``.
    """
    print(f'Batch size: {batch_size}, criterion: {criterion}')
    if previous_record is not None:
        losses = list(previous_record["losses"])
        accuracies = list(previous_record["accuracies"])
        test_accuracies = list(previous_record["test_accuracies"])
        feature_center = previous_record["feature_center"].to(device)
    else:
        losses = []
        accuracies = []
        test_accuracies = []
        feature_center = torch.zeros(
            model.num_classes, model.M * model.num_features
        ).to(device)

    cross_entropy_loss = nn.CrossEntropyLoss()
    center_loss = CenterLoss()
    # set the model to train mode initially
    model.train()
    for epoch in range(n_epochs):
        print(
            "=== Epoch {} (lr={}) === ".format(
                epoch + 1, optimizer.param_groups[0]["lr"]
            ),
            end="",
        )
        since = time.time()
        running_loss = 0.0
        running_raw_loss = 0.0
        running_aux_loss = 0.0
        running_aug_loss = 0.0
        running_feat_mat_loss = 0.0
        running_correct = 0.0
        for i, data in enumerate(train_loader, 0):
            if i % 5 == 0:
                print(f"{i} ", end="")
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()

            y = labels
            (y_pred_raw,
             y_pred_aux,
             feature_matrix, attention_map) = model(inputs)

            # Update Feature Center: move each sample's class centre toward
            # its detached feature. NOTE: with repeated labels in y, the
            # indexed in-place add keeps only one of the per-sample updates.
            feature_center_batch = F.normalize(feature_center[y], dim=-1)
            feature_center[y] += config_beta * (
                feature_matrix.detach() - feature_center_batch
            )

            # Attention augmentation without gradients: sampled channel 0
            # drives the crop, channel 1 drives the drop.
            with torch.no_grad():
                crop_images = batch_augment(
                    inputs,
                    attention_map[:, :1, :, :],
                    mode="crop",
                    theta=(0.4, 0.6),
                    padding_ratio=0.1,
                )
                drop_images = batch_augment(
                    inputs, attention_map[:, 1:, :, :], mode="drop",
                    theta=(0.2, 0.5)
                )
            aug_images = torch.cat([crop_images, drop_images], dim=0)
            y_aug = torch.cat([y, y], dim=0)

            # Forward the cropped + dropped images
            y_pred_aug, y_pred_aux_aug, _, _ = model(aug_images)

            y_pred_aux = torch.cat([y_pred_aux, y_pred_aux_aug], dim=0)
            y_aux = torch.cat([y, y_aug], dim=0)

            # loss
            raw_loss = cross_entropy_loss(y_pred_raw, y)
            aux_loss = cross_entropy_loss(y_pred_aux, y_aux)
            aug_loss = cross_entropy_loss(y_pred_aug, y_aug)
            feat_mat_loss = center_loss(feature_matrix, feature_center_batch)
            loss = (
                raw_loss / 3.0
                + aux_loss * 3.0 / 3.0
                + aug_loss * 2.0 / 3.0
                + feat_mat_loss
            )

            # backward
            loss.backward()
            optimizer.step()

            _, predicted = torch.max(y_pred_raw.data, 1)

            running_loss += loss.item()
            running_raw_loss += raw_loss.item()
            running_aux_loss += aux_loss.item()
            running_aug_loss += aug_loss.item()
            running_feat_mat_loss += feat_mat_loss.item()
            running_correct += (labels == predicted).sum().item()
        print()

        epoch_duration = time.time() - since
        epoch_loss = running_loss / len(train_loader)
        epoch_raw_loss = running_raw_loss / len(train_loader)
        epoch_aux_loss = running_aux_loss / len(train_loader)
        epoch_aug_loss = running_aug_loss / len(train_loader)
        epoch_feat_mat_loss = running_feat_mat_loss / len(train_loader)
        epoch_acc = running_correct / len(train_loader.dataset)
        print(
            "(%d s) (loss: %.3f/%.3f/%.3f/%.3f) (acc: %.4f)"
            % (
                epoch_duration,
                epoch_raw_loss,
                epoch_aux_loss,
                epoch_aug_loss,
                epoch_feat_mat_loss,
                epoch_acc,
            )
        )

        losses.append(epoch_loss)
        accuracies.append(epoch_acc)

        if no_valid is not True:
            # switch the model to eval mode to evaluate on validation data
            model.eval()
            test_acc = eval_model(model, test_loader, device)
            test_accuracies.append(test_acc)
            monitored_acc = test_acc
        else:
            monitored_acc = epoch_acc

        # re-set the model to train mode after validating
        model.train()
        scheduler.step(monitored_acc)

        model_save_name = "{}_{:.4f}.pt".format(
            datetime.now().strftime("UTC+8_%Y_%m-%d_%H:%M"), monitored_acc
        )
        model_save_path = os.path.join(model_save_dir, model_save_name)
        torch.save(
            {
                "model": model,
                "losses": losses,
                "accuracies": accuracies,
                "test_accuracies": test_accuracies,
                "feature_center": feature_center,
            },
            model_save_path,
        )

    print("Finished Training")
    return model, losses, accuracies, test_accuracies

### Training functions for the plain ResNet baseline (no CAL)

In [None]:
def eval_model_not_cal(model, test_loader, device):
    """Top-1 accuracy of a plain classifier (forward returns logits).

    The model is switched to eval mode for the pass and restored to its
    previous mode afterwards, even if an exception is raised.

    Args:
        model: module whose forward returns class logits.
        test_loader: iterable of ``(images, labels)`` batches.
        device: device the batches are moved to.

    Returns:
        Accuracy in [0, 1]; 0.0 when the loader yields no samples (the old
        version raised ZeroDivisionError).
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in test_loader:
                images = images.to(device)
                labels = labels.to(device)
                predicted = model(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)
    test_acc = correct / total if total else 0.0
    print(f"Accuracy of the network on the test images: {test_acc:.4f}")
    return test_acc


def train_model_not_cal(
    model,
    batch_size,
    criterion,
    optimizer,
    scheduler,
    n_epochs,
    device,
    train_loader,
    test_loader,
    model_save_dir,
    config_beta,
    previous_record: dict = None,
    no_valid=False,
):
    """Train a plain classifier with ``criterion`` and save per-epoch checkpoints.

    After every epoch the model is optionally validated, ``scheduler.step``
    receives the monitored accuracy, and a checkpoint dict (model, losses,
    accuracies, test_accuracies) is saved in ``model_save_dir`` with that
    accuracy in the file name.

    Args:
        model: module whose forward returns class logits.
        batch_size: only printed.
        criterion: loss applied to ``(logits, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: stepped with an accuracy value, e.g.
            ``ReduceLROnPlateau(mode="max")``.
        n_epochs: number of epochs to run.
        device: device the batches are moved to.
        train_loader: yields ``(images, labels)``; needs ``.dataset``.
        test_loader: validation loader, used unless ``no_valid`` is True.
        model_save_dir: directory for per-epoch checkpoints.
        config_beta: only printed (kept for signature parity with
            ``train_model``).
        previous_record: checkpoint dict to resume from. Its history lists
            are copied, not mutated.
        no_valid: skip validation. Training accuracy is then used for the
            scheduler and the checkpoint name (the old code raised NameError
            because ``test_acc`` was never assigned).

    Returns:
        ``(model, losses, accuracies, test_accuracies)``.
    """
    print(f'Batch size: {batch_size}, config_beta: {config_beta}')
    if previous_record is not None:
        losses = list(previous_record["losses"])
        accuracies = list(previous_record["accuracies"])
        test_accuracies = list(previous_record["test_accuracies"])
    else:
        losses = []
        accuracies = []
        test_accuracies = []

    # set the model to train mode initially
    model.train()
    for epoch in range(n_epochs):
        print(
            "=== Epoch {} (lr={}) === ".format(
                epoch + 1, optimizer.param_groups[0]["lr"]
            ),
            end="",
        )
        since = time.time()
        running_loss = 0.0
        running_correct = 0.0
        for i, data in enumerate(train_loader, 0):
            if i % 5 == 0:
                print(f"{i} ", end="")
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_correct += (labels == predicted).sum().item()
        print()

        epoch_duration = time.time() - since
        epoch_loss = running_loss / len(train_loader)
        epoch_acc = running_correct / len(train_loader.dataset)
        print(
            "(%d s) (loss: %.3f) (acc: %.4f)" % (epoch_duration,
                                                 epoch_loss, epoch_acc)
        )

        losses.append(epoch_loss)
        accuracies.append(epoch_acc)

        if no_valid is not True:
            # switch the model to eval mode to evaluate on validation data
            model.eval()
            test_acc = eval_model_not_cal(model, test_loader, device)
            test_accuracies.append(test_acc)
            monitored_acc = test_acc
        else:
            monitored_acc = epoch_acc

        # re-set the model to train mode after validating
        model.train()
        scheduler.step(monitored_acc)

        model_save_name = "{}_{:.4f}.pt".format(
            datetime.now().strftime("UTC+8_%Y_%m-%d_%H:%M"), monitored_acc
        )
        model_save_path = os.path.join(model_save_dir, model_save_name)
        torch.save(
            {
                "model": model,
                "losses": losses,
                "accuracies": accuracies,
                "test_accuracies": test_accuracies,
            },
            model_save_path,
        )

    print("Finished Training")
    return model, losses, accuracies, test_accuracies

# Training setting

### New CAL & WS-DAN model (ImageNet-pretrained ResNet-152 backbone)

In [None]:
# # Change backbone (effecient-net)
# BACKBONE = EfficientNet.from_pretrained('efficientnet-b6')
# BACKBONE_OUT_FEAT = BACKBONE._fc.in_features
# BACKBONE._fc = Identity() # Remove
# BACKBONE._swish = Identity() # Remove backbone's classifier
# MODEL = CAL(backbone=BACKBONE, backbone_out_feat=BACKBONE_OUT_FEAT,
#             num_classes=N_CLASSES, M=16, use_pytorch_resnet=False)
# MODEL = MODEL.to(DEVICE)

# Change backbone (ResNet)
BACKBONE = models.resnet152(pretrained=True)
BACKBONE_OUT_FEAT = BACKBONE.fc.in_features
# Drop global pooling so CAL receives spatial maps (flattened by the ResNet).
BACKBONE.avgpool = Identity()
BACKBONE.fc = Identity()  # Remove backbone's classifier
MODEL = CAL(
    backbone=BACKBONE,
    backbone_out_feat=BACKBONE_OUT_FEAT,
    num_classes=N_CLASSES,
    M=16,
    use_pytorch_resnet=True,
)

MODEL = MODEL.to(DEVICE)
CRITERION = nn.CrossEntropyLoss()
OPTIMIZER = optim.SGD(MODEL.parameters(), lr=LR, momentum=0.9)
# ReduceLROnPlateau's threshold is relative by default. The old threshold=0.9
# treated an epoch as progress only if accuracy beat best * 1.9, so once
# accuracy passed ~0.53 every epoch counted as a plateau and the LR was cut
# 10x every patience + 1 epochs no matter what. Require an absolute
# 0.1-percentage-point gain instead.
SCHEDULER = optim.lr_scheduler.ReduceLROnPlateau(
    OPTIMIZER, mode="max", patience=5, threshold=1e-3, threshold_mode="abs"
)

Using ResNet as feature extractor, num_classes: 200, num_attentions: 16


### New plain ResNet-152 model (ImageNet-pretrained, no CAL)

In [None]:
MODEL = models.resnet152(pretrained=True)
MODEL = MODEL.to(DEVICE)
BACKBONE_OUT_FEAT = MODEL.fc.in_features
CRITERION = nn.CrossEntropyLoss()
OPTIMIZER = optim.SGD(MODEL.parameters(), lr=LR, momentum=0.9)
# threshold is relative by default, so threshold=0.9 demanded a 90% relative
# accuracy gain and cut the LR on a fixed schedule once accuracy passed ~0.53.
# Use an absolute 0.1-percentage-point improvement instead.
SCHEDULER = optim.lr_scheduler.ReduceLROnPlateau(
    OPTIMIZER, mode="max", patience=4, threshold=1e-3, threshold_mode="abs"
)

### Resume from a saved checkpoint (CAREFUL: replaces MODEL, OPTIMIZER and SCHEDULER)

In [None]:
# Load a checkpoint dict written by train_model / train_model_not_cal.
# torch.load unpickles arbitrary objects: only load checkpoints you created.
# The record pickles a full nn.Module; on PyTorch versions whose torch.load
# defaults to weights_only=True, pass weights_only=False here.
# map_location restores GPU-saved tensors onto the current DEVICE.
PREVIOUS_RECORD = torch.load(
    os.path.join(MODEL_SAVE_DIR, "UTC+8_2021_11-01_19:22_0.8950.pt"),
    map_location=DEVICE,
)
MODEL = PREVIOUS_RECORD["model"]
MODEL = MODEL.to(DEVICE)
CRITERION = nn.CrossEntropyLoss()
OPTIMIZER = optim.SGD(MODEL.parameters(), lr=LR, momentum=0.9)
# threshold is relative by default, so threshold=0.9 demanded a 90% relative
# accuracy gain and cut the LR on a fixed schedule once accuracy passed ~0.53.
# Use an absolute 0.1-percentage-point improvement instead.
SCHEDULER = optim.lr_scheduler.ReduceLROnPlateau(
    OPTIMIZER, mode="max", patience=4, threshold=1e-3, threshold_mode="abs"
)

# Train

### Train

In [None]:
# Train the CAL model for 100 epochs, validating on VALID_DATALOADER.
# To resume from the checkpoint loaded above, pass
# previous_record=PREVIOUS_RECORD instead of None.
MODEL, train_losses, train_accs, valid_accs = train_model(
    model=MODEL,
    batch_size=BATCH_SIZE,
    criterion=CRITERION,
    optimizer=OPTIMIZER,
    scheduler=SCHEDULER,
    n_epochs=100,
    device=DEVICE,
    train_loader=TRAIN_DATALOADER,
    test_loader=VALID_DATALOADER,
    model_save_dir=MODEL_SAVE_DIR,
    config_beta=5e-2,
    previous_record=None,
)

### Change the learning rate manually and check GPU usage

In [None]:
def change_lr(optimizer, lr):
    """Overwrite the learning rate of every parameter group in ``optimizer``.

    Args:
        optimizer: a torch optimizer (anything with ``param_groups`` dicts).
        lr: new learning rate.
    """
    for param_group in optimizer.param_groups:
        param_group.update(lr=lr)


# Manually set every parameter group of OPTIMIZER to lr = 1e-4.
change_lr(OPTIMIZER, lr=0.0001)

In [None]:
# IPython shell escape: show GPU memory usage / utilisation.
!nvidia-smi

# Predict & other tools

### Prediction

In [None]:
def pred_images(model, test_loader, device):
    """Predict labels for an unlabelled loader with a CAL model plus crop TTA.

    For each batch, the raw-image logits are averaged with the logits of
    three attention-guided crops (theta / padding_ratio of 0.3 / 0.1,
    0.2 / 0.1 and 0.1 / 0.05).

    The model is put in eval mode for the pass and restored to its previous
    mode afterwards. Without this, a model straight out of ``train_model``
    (which leaves it in train mode) predicted with batch-norm batch
    statistics and cropped around a randomly sampled attention channel.

    Args:
        model: CAL model; forward returns (logits, aux, features, attention).
        test_loader: iterable of image batches (no labels).
        device: device the batches are moved to.

    Returns:
        list of predicted integer labels, in loader order.
    """
    crop_settings = ((0.3, 0.1), (0.2, 0.1), (0.1, 0.05))
    all_predicted = []
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for i, images in enumerate(test_loader):
                print(f"Iter: {i}/{len(test_loader)}")
                images = images.to(device)

                y_pred, _, _, attention_map = model(images)
                # Accumulate in the same order as before: raw + crop1 + ...
                for theta, padding_ratio in crop_settings:
                    crop_images = batch_augment(
                        images, attention_map, mode="crop",
                        theta=theta, padding_ratio=padding_ratio
                    )
                    y_pred = y_pred + model(crop_images)[0]
                y_pred = y_pred / float(1 + len(crop_settings))

                all_predicted += y_pred.argmax(dim=1).tolist()
    finally:
        model.train(was_training)
    return all_predicted


def pred_images_not_cal(model, test_loader, device):
    """Predict class indices with a single plain forward pass per batch.

    Unlike ``pred_images``, ``model(images)`` is expected to return the logits
    tensor directly (no tuple, no attention crops).

    The model is switched to eval mode for the duration of the call (so any
    BatchNorm/Dropout layers behave deterministically and BatchNorm running
    statistics are not updated by test data); its previous mode is restored
    afterwards, even if an exception is raised.

    Args:
        model: module mapping a batch of images to ``(batch, n_classes)``
            logits.
        test_loader: sized iterable of image batches (no labels).
        device: device each batch is moved to before inference.

    Returns:
        list[int]: predicted class index for every image, in loader order.
    """
    was_training = model.training
    model.eval()
    all_predicted = []
    try:
        with torch.no_grad():
            for i, images in enumerate(test_loader):
                print(f"Iter: {i}/{len(test_loader)}")
                y_pred = model(images.to(device))
                all_predicted += y_pred.argmax(dim=1).tolist()
    finally:
        model.train(was_training)
    return all_predicted

In [None]:
from datetime import timedelta, timezone

# Predict with the plain model and write one "<image name> <class name>" line
# per test image into a timestamped file under PREDICTION_DIR.
predict = pred_images_not_cal(MODEL, TEST_DATALOADER, DEVICE)
# Every test image needs exactly one prediction; fail loudly instead of
# silently writing a truncated or misaligned file.
assert len(predict) == len(TEST_NAMES), (len(predict), len(TEST_NAMES))

os.makedirs(PREDICTION_DIR, exist_ok=True)
# A naive datetime.now() uses the runtime's local zone, which need not be
# UTC+8; stamp in UTC+8 explicitly so the "UTC+8" prefix is always true.
timestamp = datetime.now(timezone(timedelta(hours=8))).strftime(
    "UTC+8_%Y_%m-%d_%H:%M")
with open(os.path.join(PREDICTION_DIR, f"{timestamp}.txt"), "w") as f:
    f.write("\n".join(
        f"{name} {CLS_LABEL_TO_NAME[label]}"
        for name, label in zip(TEST_NAMES, predict)
    ))
print("Done!")

In [None]:
from datetime import timedelta, timezone

# Predict with attention-crop test-time augmentation and write one
# "<image name> <class name>" line per test image into a timestamped file.
predict = pred_images(MODEL, TEST_DATALOADER, DEVICE)
# Every test image needs exactly one prediction; fail loudly instead of
# silently writing a truncated or misaligned file.
assert len(predict) == len(TEST_NAMES), (len(predict), len(TEST_NAMES))

os.makedirs(PREDICTION_DIR, exist_ok=True)
# A naive datetime.now() uses the runtime's local zone, which need not be
# UTC+8; stamp in UTC+8 explicitly so the "UTC+8" prefix is always true.
timestamp = datetime.now(timezone(timedelta(hours=8))).strftime(
    "UTC+8_%Y_%m-%d_%H:%M")
with open(os.path.join(PREDICTION_DIR, f"{timestamp}.txt"), "w") as f:
    f.write("\n".join(
        f"{name} {CLS_LABEL_TO_NAME[label]}"
        for name, label in zip(TEST_NAMES, predict)
    ))
print("Done!")

### Plot attention maps

In [None]:
# Pick one batch of test images (by index) to visualize attention on.
SAMPLE_BATCH_INDEX = 3
t = None
for i, inputs in enumerate(TEST_DATALOADER):
    if i == SAMPLE_BATCH_INDEX:
        t = inputs
        break
if t is None:
    raise ValueError(
        f"TEST_DATALOADER has fewer than {SAMPLE_BATCH_INDEX + 1} batches")
t = t.to(DEVICE)


def visualize(model, x):
    """Run the attention branch of ``model`` on ``x`` for visualization.

    Calls, in order, ``model.features``, ``model.attentions``, ``model.bap``
    and ``model.fc`` (the last on the first element of the ``bap`` output,
    scaled by 100).

    Returns:
        tuple: ``(p, attention_maps)`` where ``p`` is the ``model.fc`` output
        and ``attention_maps`` is the raw ``model.attentions`` output.
    """
    feats = model.features(x)

    if model.use_pytorch_resnet is True:
        # Presumably the torchvision backbone yields flattened features; fold
        # them back to (batch, num_features, side, side), assuming a square
        # spatial grid.
        n_channels = model.num_features
        side = int(np.sqrt(feats.size(1) // n_channels))
        feats = feats.view(feats.size(0), n_channels, side, side)

    attn = model.attentions(feats)
    logits = model.fc(model.bap(feats, attn)[0] * 100.0)
    return logits, attn


# Visualization-only forward pass: no autograd graph is needed, and the model
# is temporarily put in eval mode so any BatchNorm layers use (and do not
# update) their running statistics. The previous mode is restored afterwards.
_was_training = MODEL.training
MODEL.eval()
try:
    with torch.no_grad():
        ret_p, ret_atten_maps = visualize(x=t, model=MODEL)
finally:
    MODEL.train(_was_training)


def generate_heatmap(attention_maps):
    """Turn the first channel of ``attention_maps`` into an RGB heat map.

    With ``a = attention_maps[:, 0, ...]`` the channels are:
      * R = a
      * G = a where a < 0.5, else 1 - a (a triangle peaking at a = 0.5)
      * B = 1 - a

    Returns:
        Tensor with the three channels stacked on dim 1.
    """
    a = attention_maps[:, 0, ...]
    below_half = (a < 0.5).float()
    at_or_above_half = (a >= 0.5).float()

    red = a
    green = a * below_half + (1.0 - a) * at_or_above_half
    blue = 1.0 - a
    return torch.stack([red, green, blue], dim=1)


ToPILImage = transforms.ToPILImage()
# Standard ImageNet mean/std, only needed to undo normalization if the loader
# normalizes images (see the commented-out raw_image line below).
MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
# Output directory for the saved grid (defined up front so it exists even for
# an empty batch).
savepath = "/content/sample_data"

with torch.no_grad():
    # One attention map per image: max over dim 1 of the attention maps,
    # upsampled to the input resolution.
    attention_maps = torch.max(ret_atten_maps, dim=1, keepdim=True)[0]
    # F.upsample_bilinear is deprecated; this is its documented equivalent.
    attention_maps = F.interpolate(
        attention_maps, size=(t.size(2), t.size(3)),
        mode="bilinear", align_corners=True)
    # Normalize by the global max (sqrt boosts weak responses). The floor on
    # the divisor avoids NaNs when every attention value is zero.
    max_attention = max(attention_maps.max().item(), 1e-12)
    attention_maps = torch.sqrt(attention_maps.cpu() / max_attention)

    heat_attention_maps = generate_heatmap(attention_maps)

    # raw_image = t.cpu() * STD + MEAN  # use if the loader normalizes images
    raw_image = t.cpu()
    heat_attention_image = raw_image * 0.25 + heat_attention_maps * 0.75
    raw_attention_image = raw_image * attention_maps

    # Interleave [raw_0, heat_0, raw_1, heat_1, ...] in one shot instead of
    # growing a tensor with torch.cat inside a loop.
    img_concat = torch.stack(
        (raw_image, heat_attention_image), dim=1).flatten(0, 1)
    grid = make_grid(img_concat)
    os.makedirs(savepath, exist_ok=True)
    save_image(grid, os.path.join(savepath, "grid.png"))

### Show attention-guided augmentation examples (drop mode)


In [None]:
# Show one test batch next to its attention-guided "drop" augmentation: the
# grid contains the original images followed by the augmented ones.
SAMPLE_BATCH_INDEX = 3
t = None
for i, inputs in enumerate(TEST_DATALOADER):
    if i == SAMPLE_BATCH_INDEX:
        t = inputs
        break
if t is None:
    raise ValueError(
        f"TEST_DATALOADER has fewer than {SAMPLE_BATCH_INDEX + 1} batches")

# Keep images and attention maps on the same device, as pred_images does when
# it calls batch_augment; passing CPU images with device-resident attention
# maps would likely fail when DEVICE is CUDA.
# NOTE(review): MODEL runs in whatever train/eval mode it is currently in.
t_device = t.to(DEVICE)
with torch.no_grad():
    _, _, _, atm = MODEL(t_device)
    crop_images = batch_augment(t_device, atm, mode="drop",
                                theta=0.6, padding_ratio=0.1)

grid = make_grid(torch.cat((t_device, crop_images)))
arr = grid.cpu().numpy()
fig, ax = plt.subplots(figsize=(20, 10))
ax.imshow(arr.transpose(1, 2, 0))
ax.set_title("Test images followed by their attention-drop augmentations")
ax.axis("off")
plt.show()