In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import precision_score, recall_score, f1_score

from PIL import Image
import io, pickle
from tqdm import tqdm, trange
import numpy as np
import random

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
with open('images/images.pickle', 'rb') as f:
    img_data = pickle.load(f)
class CustomDataset(Dataset):
    def __init__(self, img_data, txt_file, transform=None):
        self.data = img_data
        with open(txt_file, 'r') as f:
            self.labels = f.readlines()
        self.transform = transform
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, idx):
        img_path, label = self.labels[idx].strip().split()
        img_path = img_path.split('/')
        img = self.data[img_path[1]][img_path[2]]
        img = Image.open(io.BytesIO(img)).convert('RGB')
        if self.transform:
            img = self.transform(img)
        label = torch.tensor(int(label))
        return img, label


transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
dataset_train = CustomDataset(img_data=img_data, txt_file='images/train.txt', transform=transform)
dataset_val = CustomDataset(img_data=img_data, txt_file='images/val.txt', transform=transform)
dataset_test = CustomDataset(img_data=img_data, txt_file='images/test.txt', transform=transform)


data_loader_train = DataLoader(dataset_train, batch_size=32, shuffle=True, )
data_loader_val = DataLoader(dataset_val, batch_size=32, shuffle=False)
data_loader_test = DataLoader(dataset_test, batch_size=32, shuffle=False)


In [9]:
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        
    def forward(self, x):
        batch_size, C, width, height = x.size()
        
        # 计算查询、键和值
        proj_query = self.query_conv(x).view(batch_size, -1, width * height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(batch_size, -1, width * height)
        energy = torch.bmm(proj_query, proj_key)
        
        # 计算注意力权重并应用到值上
        attention = F.softmax(energy, dim=-1)
        proj_value = self.value_conv(x).view(batch_size, -1, width * height)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        
        # 调整输出并加上残差连接
        out = out.view(batch_size, C, width, height)
        out = self.gamma * out + x
        return out

class CNNModel(nn.Module):
    def __init__(self, num_classes=50):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.attention1 = SelfAttention(128)

        self.fc1 = nn.Linear(128*56*56, 1024)
        # self.fc2 = nn.Linear(4096, 512)
        # self.fc3 = nn.Linear(512, num_classes)
        self.fc2 = nn.Linear(1024, num_classes)
        self.activation = nn.ReLU()
        self.flatten = nn.Flatten()
        self.pool = nn.MaxPool2d(2, 2)
        # self.dropout25 = nn.Dropout(0.25)
        self.dropout50 = nn.Dropout(0.5)
        
    def forward(self, x):
        x = self.pool(self.activation(self.bn1(self.conv1(x))))
        x = self.pool(self.activation(self.bn2(self.conv2(x))))
        x = self.attention1(x)
        x = self.dropout50(self.flatten(x))
        x = self.dropout50(self.activation(self.fc1(x)))
        # x = self.dropout50(self.activation(self.fc2(x)))
        # x = self.fc3(x)
        x = self.fc2(x)
        return x


for lr in [1e-6]:
    model = CNNModel(num_classes=50).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    num_epochs = 10
    save_path = 'record/Q2_less_layers'
    model_path = f'{save_path}/model_weight/model_weight_lr{lr}_bn_dropout50.pth'
    best_val_loss = float('inf')
    best_model = None
    train_loss = []
    train_acc = []
    val_loss = []
    val_acc = []
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_acc = 0.0
        pbar = tqdm(data_loader_train, colour='green', total=len(data_loader_train))
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()*len(labels)
            running_acc += (outputs.argmax(dim=1) == labels).float().sum().item()
            pbar.set_description(f'Train Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
            pbar.update()
        running_loss_avg = running_loss / len(dataset_train)
        running_acc_avg = running_acc / len(dataset_train)
        train_loss.append(running_loss_avg)
        train_acc.append(running_acc_avg)
        print(f"Epoch {epoch+1}, Loss: {running_loss_avg}, Accuracy: {running_acc_avg}")

        model.eval()
        with torch.no_grad():
            class_counts = [0] * 50
            running_loss = 0.0
            running_acc = 0.0
            pbar = tqdm(data_loader_val, colour='red', total=len(data_loader_val))
            for images, labels in pbar:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                predictions = outputs.argmax(dim=1)
                loss = criterion(outputs, labels)
                running_loss += loss.item()*len(labels)
                running_acc += (predictions == labels).float().sum().item()
                pbar.set_description(f'Valid Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
                pbar.update()

                for pred_label in predictions:
                    class_counts[pred_label] += 1

            running_loss_avg = running_loss / len(dataset_val)
            running_acc_avg = running_acc / len(dataset_val)
            val_loss.append(running_loss_avg)
            val_acc.append(running_acc_avg)
            print(f"Validation Loss: {running_loss_avg}, Accuracy: {running_acc_avg}")
            print(f"Class Counts: {class_counts}")
            if running_loss_avg < best_val_loss:
                best_val_loss = running_loss
                best_model = model
                torch.save(model.state_dict(), model_path)
        print("======================================")
    print("Finished Training")


    model.eval()
    with torch.no_grad():
        class_counts = [0] * 50
        running_loss = 0.0
        running_acc = 0.0
        predictions = []
        true_labels = []
        pbar = tqdm(data_loader_test, colour='red', total=len(data_loader_test))
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            _, predicted = torch.max(outputs, 1)
            predictions.extend(predicted.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())
            loss = criterion(outputs, labels)
            running_loss += loss.item()*len(labels)
            running_acc += (outputs.argmax(dim=1) == labels).float().sum().item()
            pbar.set_description(f'Test Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
            pbar.update()

        for pred_label in predictions:
            class_counts[pred_label] += 1
        
        precision = precision_score(true_labels, predictions, average='weighted')
        recall = recall_score(true_labels, predictions, average='weighted')
        f1 = f1_score(true_labels, predictions, average='weighted')
        print(f"Test Loss: {running_loss/len(dataset_test)}, Accuracy: {running_acc/len(dataset_test)}")
        print(f"Precision: {precision}, Recall: {recall}, F1: {f1}")
        print(f"Class Counts: {class_counts}")
    print("Finished Testing")

    train_record = {"train_loss": train_loss, "train_acc": train_acc, "val_loss": val_loss, "val_acc": val_acc}
    test_record = {"true_labels": true_labels, "predictions": predictions, "precision": precision, "recall": recall, "f1": f1, "class_counts": class_counts}


    with open(f'{save_path}/train_record/train_record_lr{lr}_bn_dropout50.pkl', 'wb') as f:
        pickle.dump(train_record, f)
    with open(f'{save_path}/test_record/test_record_lr{lr}_bn_dropout50.pkl', 'wb') as f:
        pickle.dump(test_record, f)


Train Iter: 001/010  Loss: 3.2567: 100%|[32m██████████[0m| 1979/1979 [04:23<00:00,  7.51it/s]


Epoch 1, Loss: 3.5076264987159425, Accuracy: 0.10776154757204895


Valid Iter: 001/010  Loss: 2.4816: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 22.42it/s]


Validation Loss: 3.167535139719645, Accuracy: 0.1622222222222222
Class Counts: [12, 8, 44, 9, 23, 5, 6, 23, 1, 1, 0, 1, 0, 5, 3, 5, 0, 4, 13, 27, 13, 8, 4, 7, 7, 0, 0, 14, 1, 5, 2, 0, 4, 24, 13, 1, 25, 7, 3, 8, 13, 10, 14, 2, 2, 12, 37, 3, 16, 5]


Train Iter: 002/010  Loss: 3.2722: 100%|[32m██████████[0m| 1979/1979 [04:24<00:00,  7.47it/s]


Epoch 2, Loss: 3.1977211644119583, Accuracy: 0.16341097512830635


Valid Iter: 002/010  Loss: 1.8361: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 22.18it/s]


Validation Loss: 3.014883509741889, Accuracy: 0.20222222222222222
Class Counts: [4, 15, 9, 14, 18, 10, 7, 0, 1, 3, 0, 0, 0, 11, 5, 17, 0, 7, 13, 10, 15, 2, 3, 15, 0, 13, 0, 12, 19, 4, 8, 3, 4, 14, 30, 7, 3, 12, 8, 11, 18, 16, 2, 22, 10, 6, 15, 2, 18, 14]


Train Iter: 003/010  Loss: 3.1025: 100%|[32m██████████[0m| 1979/1979 [04:24<00:00,  7.49it/s]


Epoch 3, Loss: 3.0427990581619375, Accuracy: 0.1962258191867351


Valid Iter: 003/010  Loss: 1.6249: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 22.56it/s]


Validation Loss: 2.915509593221876, Accuracy: 0.2288888888888889
Class Counts: [11, 10, 19, 24, 13, 10, 5, 5, 2, 4, 0, 1, 3, 13, 14, 10, 4, 9, 12, 5, 13, 8, 2, 7, 13, 7, 1, 5, 10, 5, 11, 2, 2, 24, 20, 5, 11, 8, 10, 15, 12, 6, 8, 4, 6, 18, 11, 5, 13, 14]


Train Iter: 004/010  Loss: 2.6851: 100%|[32m██████████[0m| 1979/1979 [04:24<00:00,  7.48it/s]


Epoch 4, Loss: 2.930684765910989, Accuracy: 0.22146071851559415


Valid Iter: 004/010  Loss: 1.3721: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 22.28it/s]


Validation Loss: 2.86415381749471, Accuracy: 0.2288888888888889
Class Counts: [5, 9, 25, 4, 18, 14, 7, 6, 8, 3, 6, 0, 3, 10, 4, 4, 21, 9, 3, 11, 7, 10, 18, 4, 5, 17, 1, 6, 12, 3, 3, 1, 5, 15, 26, 9, 3, 9, 8, 18, 15, 5, 2, 16, 4, 8, 13, 2, 17, 18]


Train Iter: 005/010  Loss: 2.8861: 100%|[32m██████████[0m| 1979/1979 [04:24<00:00,  7.47it/s]


Epoch 5, Loss: 2.826700122018506, Accuracy: 0.24873272799052507


Valid Iter: 005/010  Loss: 1.1240: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 22.07it/s]


Validation Loss: 2.83906281153361, Accuracy: 0.24
Class Counts: [6, 8, 13, 5, 20, 5, 11, 2, 5, 2, 1, 0, 4, 9, 18, 13, 6, 7, 23, 16, 9, 3, 14, 6, 0, 1, 0, 5, 11, 25, 1, 1, 7, 21, 26, 9, 2, 7, 5, 17, 16, 5, 7, 15, 3, 7, 8, 3, 18, 24]


Train Iter: 006/010  Loss: 2.3610: 100%|[32m██████████[0m| 1979/1979 [04:23<00:00,  7.51it/s]


Epoch 6, Loss: 2.7301098564891912, Accuracy: 0.27166206079747335


Valid Iter: 006/010  Loss: 0.8732: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 22.35it/s]


Validation Loss: 2.796951734224955, Accuracy: 0.24
Class Counts: [6, 13, 18, 9, 23, 3, 10, 12, 8, 3, 2, 0, 9, 3, 11, 6, 14, 10, 6, 14, 10, 3, 17, 5, 8, 8, 2, 5, 10, 9, 6, 1, 7, 10, 19, 8, 16, 6, 8, 2, 15, 12, 4, 15, 5, 6, 10, 3, 20, 20]


Train Iter: 007/010  Loss: 2.7471: 100%|[32m██████████[0m| 1979/1979 [04:25<00:00,  7.45it/s]


Epoch 7, Loss: 2.6398464951701563, Accuracy: 0.2934228187919463


Valid Iter: 007/010  Loss: 1.1545: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 21.03it/s]


Validation Loss: 2.7809141688876684, Accuracy: 0.24222222222222223
Class Counts: [8, 10, 25, 4, 28, 8, 4, 8, 2, 6, 2, 11, 4, 11, 10, 11, 11, 4, 13, 10, 5, 8, 16, 7, 0, 4, 3, 4, 10, 13, 0, 2, 14, 12, 8, 10, 11, 6, 8, 4, 13, 12, 1, 27, 7, 13, 7, 2, 13, 20]


Train Iter: 008/010  Loss: 2.4822: 100%|[32m██████████[0m| 1979/1979 [04:23<00:00,  7.51it/s]


Epoch 8, Loss: 2.5535960705234184, Accuracy: 0.3151204105803395


Valid Iter: 008/010  Loss: 1.4317: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 22.11it/s]


Validation Loss: 2.7566209660636054, Accuracy: 0.26222222222222225
Class Counts: [8, 14, 16, 5, 22, 5, 8, 6, 4, 2, 6, 0, 2, 5, 6, 3, 7, 5, 9, 7, 17, 7, 7, 5, 7, 7, 4, 5, 19, 5, 7, 9, 11, 10, 9, 10, 16, 13, 10, 21, 16, 10, 10, 19, 6, 7, 11, 5, 17, 10]


Train Iter: 009/010  Loss: 2.7083: 100%|[32m██████████[0m| 1979/1979 [04:24<00:00,  7.47it/s]


Epoch 9, Loss: 2.470270771733678, Accuracy: 0.33572838531385707


Valid Iter: 009/010  Loss: 1.1048: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 20.23it/s]


Validation Loss: 2.72089382648468, Accuracy: 0.2577777777777778
Class Counts: [3, 9, 7, 6, 21, 8, 9, 8, 3, 3, 2, 4, 6, 4, 4, 12, 6, 11, 12, 14, 9, 5, 3, 23, 1, 6, 2, 9, 15, 13, 4, 6, 7, 17, 17, 12, 7, 13, 6, 12, 11, 14, 6, 23, 9, 10, 10, 5, 11, 12]


Train Iter: 010/010  Loss: 2.5836: 100%|[32m██████████[0m| 1979/1979 [04:24<00:00,  7.48it/s]


Epoch 10, Loss: 2.385075069880043, Accuracy: 0.35946308724832216


Valid Iter: 010/010  Loss: 1.1329: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 22.26it/s]


Validation Loss: 2.725719035466512, Accuracy: 0.24888888888888888
Class Counts: [5, 2, 17, 4, 22, 16, 8, 10, 7, 7, 3, 0, 5, 8, 11, 8, 3, 5, 14, 5, 32, 4, 7, 8, 3, 9, 11, 2, 18, 2, 4, 14, 13, 8, 9, 3, 7, 15, 6, 11, 16, 11, 3, 19, 7, 12, 11, 2, 12, 11]
Finished Training


Test Iter: 010/010  Loss: 1.1329: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 22.40it/s]


Test Loss: 1.1328809261322021, Accuracy: 0.28
Precision: 0.29624302878139197, Recall: 0.28, F1: 0.262235022776772
Class Counts: [11, 3, 15, 4, 9, 20, 8, 15, 2, 6, 1, 0, 3, 7, 10, 15, 1, 4, 17, 4, 25, 2, 12, 13, 3, 8, 11, 3, 23, 3, 3, 10, 8, 21, 11, 1, 3, 11, 7, 7, 18, 14, 5, 11, 10, 12, 14, 5, 9, 12]
Finished Testing


  _warn_prf(average, modifier, msg_start, len(result))


In [10]:
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        
    def forward(self, x):
        batch_size, C, width, height = x.size()
        
        # 计算查询、键和值
        proj_query = self.query_conv(x).view(batch_size, -1, width * height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(batch_size, -1, width * height)
        energy = torch.bmm(proj_query, proj_key)
        
        # 计算注意力权重并应用到值上
        attention = F.softmax(energy, dim=-1)
        proj_value = self.value_conv(x).view(batch_size, -1, width * height)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        
        # 调整输出并加上残差连接
        out = out.view(batch_size, C, width, height)
        out = self.gamma * out + x
        return out

class CNNModel(nn.Module):
    def __init__(self, num_classes=50):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.attention1 = SelfAttention(128)

        self.fc1 = nn.Linear(128*56*56, 1024)
        # self.fc2 = nn.Linear(4096, 512)
        # self.fc3 = nn.Linear(512, num_classes)
        self.fc2 = nn.Linear(1024, num_classes)
        self.activation = nn.ReLU()
        self.flatten = nn.Flatten()
        self.pool = nn.MaxPool2d(2, 2)
        # self.dropout25 = nn.Dropout(0.25)
        self.dropout50 = nn.Dropout(0.5)
        
    def forward(self, x):
        x = self.pool(self.activation(self.bn1(self.conv1(x))))
        x = self.pool(self.activation(self.bn2(self.conv2(x))))
        x = self.attention1(x)
        x = self.dropout50(self.flatten(x))
        x = self.dropout50(self.activation(self.fc1(x)))
        # x = self.dropout50(self.activation(self.fc2(x)))
        # x = self.fc3(x)
        x = self.fc2(x)
        return x


for lr in [1e-5]:
    model = CNNModel(num_classes=50).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    num_epochs = 10
    save_path = 'record/Q2_less_layers'
    model_path = f'{save_path}/model_weight/model_weight_lr{lr}_bn_dropout50.pth'
    best_val_loss = float('inf')
    best_model = None
    train_loss = []
    train_acc = []
    val_loss = []
    val_acc = []
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_acc = 0.0
        pbar = tqdm(data_loader_train, colour='green', total=len(data_loader_train))
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()*len(labels)
            running_acc += (outputs.argmax(dim=1) == labels).float().sum().item()
            pbar.set_description(f'Train Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
            pbar.update()
        running_loss_avg = running_loss / len(dataset_train)
        running_acc_avg = running_acc / len(dataset_train)
        train_loss.append(running_loss_avg)
        train_acc.append(running_acc_avg)
        print(f"Epoch {epoch+1}, Loss: {running_loss_avg}, Accuracy: {running_acc_avg}")

        model.eval()
        with torch.no_grad():
            class_counts = [0] * 50
            running_loss = 0.0
            running_acc = 0.0
            pbar = tqdm(data_loader_val, colour='red', total=len(data_loader_val))
            for images, labels in pbar:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                predictions = outputs.argmax(dim=1)
                loss = criterion(outputs, labels)
                running_loss += loss.item()*len(labels)
                running_acc += (predictions == labels).float().sum().item()
                pbar.set_description(f'Valid Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
                pbar.update()

                for pred_label in predictions:
                    class_counts[pred_label] += 1

            running_loss_avg = running_loss / len(dataset_val)
            running_acc_avg = running_acc / len(dataset_val)
            val_loss.append(running_loss_avg)
            val_acc.append(running_acc_avg)
            print(f"Validation Loss: {running_loss_avg}, Accuracy: {running_acc_avg}")
            print(f"Class Counts: {class_counts}")
            if running_loss_avg < best_val_loss:
                best_val_loss = running_loss
                best_model = model
                torch.save(model.state_dict(), model_path)
        print("======================================")
    print("Finished Training")


    model.eval()
    with torch.no_grad():
        class_counts = [0] * 50
        running_loss = 0.0
        running_acc = 0.0
        predictions = []
        true_labels = []
        pbar = tqdm(data_loader_test, colour='red', total=len(data_loader_test))
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            _, predicted = torch.max(outputs, 1)
            predictions.extend(predicted.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())
            loss = criterion(outputs, labels)
            running_loss += loss.item()*len(labels)
            running_acc += (outputs.argmax(dim=1) == labels).float().sum().item()
            pbar.set_description(f'Test Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
            pbar.update()

        for pred_label in predictions:
            class_counts[pred_label] += 1
        
        precision = precision_score(true_labels, predictions, average='weighted')
        recall = recall_score(true_labels, predictions, average='weighted')
        f1 = f1_score(true_labels, predictions, average='weighted')
        print(f"Test Loss: {running_loss/len(dataset_test)}, Accuracy: {running_acc/len(dataset_test)}")
        print(f"Precision: {precision}, Recall: {recall}, F1: {f1}")
        print(f"Class Counts: {class_counts}")
    print("Finished Testing")

    train_record = {"train_loss": train_loss, "train_acc": train_acc, "val_loss": val_loss, "val_acc": val_acc}
    test_record = {"true_labels": true_labels, "predictions": predictions, "precision": precision, "recall": recall, "f1": f1, "class_counts": class_counts}


    with open(f'{save_path}/train_record/train_record_lr{lr}_bn_dropout50.pkl', 'wb') as f:
        pickle.dump(train_record, f)
    with open(f'{save_path}/test_record/test_record_lr{lr}_bn_dropout50.pkl', 'wb') as f:
        pickle.dump(test_record, f)


Train Iter: 001/010  Loss: 3.4875: 100%|[32m██████████[0m| 1979/1979 [04:22<00:00,  7.54it/s]


Epoch 1, Loss: 3.6021053779412293, Accuracy: 0.08938018160284247


Valid Iter: 001/010  Loss: 2.4024: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 22.41it/s]


Validation Loss: 3.2657573223114014, Accuracy: 0.15777777777777777
Class Counts: [15, 12, 8, 39, 16, 2, 1, 2, 1, 4, 2, 2, 0, 0, 9, 5, 18, 0, 7, 15, 7, 10, 12, 34, 0, 1, 1, 2, 8, 2, 8, 0, 18, 17, 26, 2, 2, 6, 6, 7, 9, 16, 10, 2, 6, 10, 33, 6, 22, 9]


Train Iter: 002/010  Loss: 3.3236: 100%|[32m██████████[0m| 1979/1979 [04:23<00:00,  7.52it/s]


Epoch 2, Loss: 3.329084640240133, Accuracy: 0.1338491906829846


Valid Iter: 002/010  Loss: 2.1629: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 22.28it/s]


Validation Loss: 3.0691118314531116, Accuracy: 0.19555555555555557
Class Counts: [16, 12, 12, 15, 28, 11, 1, 7, 0, 0, 2, 0, 0, 0, 10, 18, 8, 0, 8, 5, 3, 3, 12, 19, 1, 2, 4, 6, 4, 13, 2, 1, 3, 12, 28, 43, 8, 14, 6, 11, 2, 5, 11, 6, 12, 5, 8, 10, 25, 18]


Train Iter: 003/010  Loss: 3.4495: 100%|[32m██████████[0m| 1979/1979 [04:24<00:00,  7.50it/s]


Epoch 3, Loss: 3.1528836255802073, Accuracy: 0.16674299249901303


Valid Iter: 003/010  Loss: 1.7792: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 21.02it/s]


Validation Loss: 2.913243222236633, Accuracy: 0.22666666666666666
Class Counts: [8, 2, 21, 26, 18, 5, 4, 6, 3, 7, 7, 2, 0, 2, 17, 13, 5, 5, 10, 8, 16, 7, 14, 4, 7, 4, 6, 10, 7, 9, 3, 3, 6, 18, 17, 15, 10, 9, 6, 24, 5, 2, 3, 13, 7, 8, 14, 1, 15, 18]


Train Iter: 004/010  Loss: 2.9382: 100%|[32m██████████[0m| 1979/1979 [04:24<00:00,  7.50it/s]


Epoch 4, Loss: 2.9906473729773277, Accuracy: 0.19876825898144493


Valid Iter: 004/010  Loss: 2.2322: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 22.30it/s]


Validation Loss: 2.8271414947509768, Accuracy: 0.25333333333333335
Class Counts: [4, 18, 17, 16, 20, 11, 11, 11, 3, 1, 2, 1, 3, 2, 13, 8, 4, 5, 13, 3, 12, 5, 7, 6, 7, 6, 6, 7, 13, 9, 8, 4, 6, 21, 17, 11, 11, 8, 3, 13, 12, 15, 4, 13, 11, 9, 15, 2, 16, 7]


Train Iter: 005/010  Loss: 2.9288: 100%|[32m██████████[0m| 1979/1979 [04:23<00:00,  7.50it/s]


Epoch 5, Loss: 2.8346489287834604, Accuracy: 0.23253059613106988


Valid Iter: 005/010  Loss: 1.3665: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 21.95it/s]


Validation Loss: 2.800333530637953, Accuracy: 0.22
Class Counts: [6, 4, 26, 14, 20, 13, 7, 6, 5, 0, 0, 5, 3, 16, 1, 8, 4, 4, 19, 7, 22, 5, 5, 11, 2, 1, 14, 9, 8, 8, 6, 4, 9, 16, 22, 4, 4, 7, 13, 4, 9, 11, 10, 11, 9, 11, 13, 5, 14, 15]


Train Iter: 006/010  Loss: 2.7430: 100%|[32m██████████[0m| 1979/1979 [04:24<00:00,  7.47it/s]


Epoch 6, Loss: 2.6666165786950993, Accuracy: 0.2726727200947493


Valid Iter: 006/010  Loss: 1.2419: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 18.70it/s]


Validation Loss: 2.736761645211114, Accuracy: 0.2111111111111111
Class Counts: [6, 9, 14, 13, 15, 2, 5, 6, 1, 8, 0, 6, 5, 11, 8, 13, 7, 3, 14, 24, 13, 4, 6, 11, 9, 6, 14, 12, 20, 5, 10, 3, 9, 9, 9, 4, 5, 5, 9, 19, 9, 10, 8, 11, 9, 5, 11, 6, 18, 11]


Train Iter: 007/010  Loss: 2.2880: 100%|[32m██████████[0m| 1979/1979 [04:24<00:00,  7.49it/s]


Epoch 7, Loss: 2.4869349667398155, Accuracy: 0.31919463087248323


Valid Iter: 007/010  Loss: 1.0271: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 22.41it/s]


Validation Loss: 2.651998969184028, Accuracy: 0.26
Class Counts: [6, 7, 12, 15, 15, 5, 5, 8, 5, 6, 3, 1, 3, 5, 15, 5, 5, 6, 12, 21, 11, 9, 5, 11, 20, 9, 7, 10, 8, 13, 6, 6, 9, 7, 16, 8, 13, 9, 7, 11, 6, 10, 7, 15, 9, 10, 10, 4, 15, 9]


Train Iter: 008/010  Loss: 2.3329: 100%|[32m██████████[0m| 1979/1979 [04:24<00:00,  7.49it/s]


Epoch 8, Loss: 2.291834581271961, Accuracy: 0.36923805763916306


Valid Iter: 008/010  Loss: 0.8180: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 22.07it/s]


Validation Loss: 2.646349369949765, Accuracy: 0.26
Class Counts: [6, 14, 13, 5, 18, 5, 4, 11, 10, 6, 3, 5, 6, 2, 10, 11, 10, 9, 14, 11, 16, 4, 6, 6, 13, 8, 10, 15, 7, 6, 8, 4, 13, 3, 7, 3, 6, 12, 6, 12, 16, 7, 13, 15, 8, 10, 16, 4, 16, 7]


Train Iter: 009/010  Loss: 2.0092: 100%|[32m██████████[0m| 1979/1979 [04:23<00:00,  7.51it/s]


Epoch 9, Loss: 2.0796312836362025, Accuracy: 0.42321358073430715


Valid Iter: 009/010  Loss: 0.6387: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 22.42it/s]


Validation Loss: 2.646788297759162, Accuracy: 0.27555555555555555
Class Counts: [8, 7, 9, 9, 18, 5, 4, 12, 5, 2, 3, 1, 4, 4, 13, 5, 10, 11, 11, 6, 19, 6, 4, 12, 4, 8, 12, 6, 7, 8, 15, 6, 13, 7, 5, 19, 17, 8, 13, 9, 15, 9, 7, 9, 9, 11, 10, 6, 13, 16]


Train Iter: 010/010  Loss: 2.1437: 100%|[32m██████████[0m| 1979/1979 [04:24<00:00,  7.48it/s]


Epoch 10, Loss: 1.8503187420525384, Accuracy: 0.4885748124753257


Valid Iter: 010/010  Loss: 0.9684: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 22.36it/s]


Validation Loss: 2.6346687671873306, Accuracy: 0.2822222222222222
Class Counts: [7, 13, 11, 10, 17, 6, 5, 14, 5, 6, 2, 6, 3, 9, 13, 11, 6, 14, 9, 14, 13, 4, 14, 8, 3, 10, 12, 3, 7, 12, 8, 0, 11, 12, 11, 9, 14, 6, 6, 11, 15, 4, 8, 10, 9, 8, 10, 9, 14, 8]
Finished Training


Test Iter: 010/010  Loss: 0.9684: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 17.78it/s]

Test Loss: 0.9684146642684937, Accuracy: 0.2577777777777778
Precision: 0.28378178194354664, Recall: 0.2577777777777778, F1: 0.2538292060880296
Class Counts: [8, 10, 9, 12, 7, 11, 5, 12, 3, 4, 3, 6, 2, 15, 16, 17, 13, 15, 13, 8, 7, 2, 7, 13, 7, 15, 6, 6, 6, 8, 4, 6, 8, 26, 10, 10, 8, 6, 10, 10, 10, 7, 2, 8, 7, 11, 9, 11, 12, 9]
Finished Testing





In [5]:
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        
    def forward(self, x):
        batch_size, C, width, height = x.size()
        
        # 计算查询、键和值
        proj_query = self.query_conv(x).view(batch_size, -1, width * height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(batch_size, -1, width * height)
        energy = torch.bmm(proj_query, proj_key)
        
        # 计算注意力权重并应用到值上
        attention = F.softmax(energy, dim=-1)
        proj_value = self.value_conv(x).view(batch_size, -1, width * height)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        
        # 调整输出并加上残差连接
        out = out.view(batch_size, C, width, height)
        out = self.gamma * out + x
        return out

class CNNModel(nn.Module):
    def __init__(self, num_classes=50):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.attention1 = SelfAttention(256)

        self.fc1 = nn.Linear(256*28*28, 1024)
        self.fc2 = nn.Linear(1024, num_classes)
        self.activation = nn.ReLU()
        self.flatten = nn.Flatten()
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout25 = nn.Dropout(0.25)
        self.dropout50 = nn.Dropout(0.5)
        
    def forward(self, x):
        x = self.pool(self.activation(self.bn1(self.conv1(x))))
        x = self.pool(self.activation(self.bn2(self.conv2(x))))
        x = self.pool(self.activation(self.bn3(self.conv3(x))))
        x = self.attention1(x)
        x = self.dropout25(self.flatten(x))
        x = self.dropout50(self.activation(self.fc1(x)))
        x = self.fc2(x)
        return x


for lr in [1e-5]:
    model = CNNModel(num_classes=50).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    num_epochs = 10
    save_path = 'record/Q2_less_layers'
    rec_name = '6layers'
    model_path = f'{save_path}/model_weight/model_weight_lr{lr}_{rec_name}.pth'
    best_val_loss = float('inf')
    best_model = None
    train_loss = []
    train_acc = []
    val_loss = []
    val_acc = []
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_acc = 0.0
        pbar = tqdm(data_loader_train, colour='green', total=len(data_loader_train))
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()*len(labels)
            running_acc += (outputs.argmax(dim=1) == labels).float().sum().item()
            pbar.set_description(f'Train Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
            pbar.update()
        running_loss_avg = running_loss / len(dataset_train)
        running_acc_avg = running_acc / len(dataset_train)
        train_loss.append(running_loss_avg)
        train_acc.append(running_acc_avg)
        print(f"Epoch {epoch+1}, Loss: {running_loss_avg}, Accuracy: {running_acc_avg}")

        model.eval()
        with torch.no_grad():
            class_counts = [0] * 50
            running_loss = 0.0
            running_acc = 0.0
            pbar = tqdm(data_loader_val, colour='red', total=len(data_loader_val))
            for images, labels in pbar:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                predictions = outputs.argmax(dim=1)
                loss = criterion(outputs, labels)
                running_loss += loss.item()*len(labels)
                running_acc += (predictions == labels).float().sum().item()
                pbar.set_description(f'Valid Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
                pbar.update()

                for pred_label in predictions:
                    class_counts[pred_label] += 1

            running_loss_avg = running_loss / len(dataset_val)
            running_acc_avg = running_acc / len(dataset_val)
            val_loss.append(running_loss_avg)
            val_acc.append(running_acc_avg)
            print(f"Validation Loss: {running_loss_avg}, Accuracy: {running_acc_avg}")
            print(f"Class Counts: {class_counts}")
            if running_loss_avg < best_val_loss:
                best_val_loss = running_loss
                best_model = model
                torch.save(model.state_dict(), model_path)
        print("======================================")
    print("Finished Training")


    model.eval()
    with torch.no_grad():
        class_counts = [0] * 50
        running_loss = 0.0
        running_acc = 0.0
        predictions = []
        true_labels = []
        pbar = tqdm(data_loader_test, colour='red', total=len(data_loader_test))
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            _, predicted = torch.max(outputs, 1)
            predictions.extend(predicted.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())
            loss = criterion(outputs, labels)
            running_loss += loss.item()*len(labels)
            running_acc += (outputs.argmax(dim=1) == labels).float().sum().item()
            pbar.set_description(f'Test Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
            pbar.update()

        for pred_label in predictions:
            class_counts[pred_label] += 1
        
        precision = precision_score(true_labels, predictions, average='weighted')
        recall = recall_score(true_labels, predictions, average='weighted')
        f1 = f1_score(true_labels, predictions, average='weighted')
        print(f"Test Loss: {running_loss/len(dataset_test)}, Accuracy: {running_acc/len(dataset_test)}")
        print(f"Precision: {precision}, Recall: {recall}, F1: {f1}")
        print(f"Class Counts: {class_counts}")
    print("Finished Testing")

    train_record = {"train_loss": train_loss, "train_acc": train_acc, "val_loss": val_loss, "val_acc": val_acc}
    test_record = {"true_labels": true_labels, "predictions": predictions, "precision": precision, "recall": recall, "f1": f1, "class_counts": class_counts}


    with open(f'{save_path}/train_record/train_record_lr{lr}_{rec_name}.pkl', 'wb') as f:
        pickle.dump(train_record, f)
    with open(f'{save_path}/test_record/test_record_lr{lr}_{rec_name}.pkl', 'wb') as f:
        pickle.dump(test_record, f)


Train Iter: 001/010  Loss: 3.4272: 100%|[32m██████████[0m| 1979/1979 [02:52<00:00, 11.50it/s]


Epoch 1, Loss: 3.4936507674125736, Accuracy: 0.10836162652980655


Valid Iter: 001/010  Loss: 1.8520: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.47it/s]


Validation Loss: 3.0952224588394164, Accuracy: 0.19555555555555557
Class Counts: [18, 12, 9, 14, 25, 4, 1, 4, 2, 0, 0, 2, 0, 1, 1, 27, 8, 0, 2, 3, 6, 5, 4, 30, 8, 0, 0, 16, 1, 22, 20, 1, 17, 15, 13, 4, 7, 11, 13, 10, 3, 11, 3, 22, 6, 10, 20, 1, 18, 20]


Train Iter: 002/010  Loss: 2.9536: 100%|[32m██████████[0m| 1979/1979 [02:52<00:00, 11.46it/s]


Epoch 2, Loss: 3.1504145529144987, Accuracy: 0.16885906040268456


Valid Iter: 002/010  Loss: 1.6306: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.55it/s]


Validation Loss: 2.8539898522694904, Accuracy: 0.22
Class Counts: [15, 7, 14, 5, 37, 12, 4, 7, 4, 1, 0, 0, 1, 2, 4, 0, 15, 2, 6, 7, 7, 11, 12, 20, 6, 7, 2, 11, 4, 14, 1, 7, 13, 10, 19, 13, 2, 9, 7, 21, 7, 5, 24, 17, 5, 9, 22, 2, 13, 7]


Train Iter: 003/010  Loss: 2.7541: 100%|[32m██████████[0m| 1979/1979 [02:52<00:00, 11.49it/s]


Epoch 3, Loss: 2.934891892872542, Accuracy: 0.21413343861034348


Valid Iter: 003/010  Loss: 1.3608: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.74it/s]


Validation Loss: 2.7147503333621557, Accuracy: 0.26222222222222225
Class Counts: [3, 1, 16, 16, 10, 14, 4, 9, 3, 3, 3, 0, 1, 2, 6, 4, 22, 8, 6, 6, 9, 6, 21, 6, 3, 13, 1, 17, 7, 8, 4, 3, 13, 21, 5, 14, 10, 15, 5, 18, 22, 13, 9, 18, 11, 7, 15, 2, 11, 6]


Train Iter: 004/010  Loss: 2.7054: 100%|[32m██████████[0m| 1979/1979 [02:50<00:00, 11.63it/s]


Epoch 4, Loss: 2.7310952638986787, Accuracy: 0.2589498618239242


Valid Iter: 004/010  Loss: 1.2528: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.33it/s]


Validation Loss: 2.607575257619222, Accuracy: 0.3022222222222222
Class Counts: [10, 1, 19, 10, 22, 14, 5, 6, 6, 4, 7, 1, 2, 3, 10, 13, 5, 4, 11, 18, 5, 8, 9, 6, 4, 15, 3, 17, 5, 7, 6, 8, 3, 14, 17, 6, 12, 5, 6, 9, 18, 13, 11, 21, 12, 7, 5, 8, 14, 5]


Train Iter: 005/010  Loss: 2.7843: 100%|[32m██████████[0m| 1979/1979 [02:51<00:00, 11.53it/s]


Epoch 5, Loss: 2.5337819131202863, Accuracy: 0.3056296881168575


Valid Iter: 005/010  Loss: 0.6859: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.60it/s]


Validation Loss: 2.5518478679656984, Accuracy: 0.3288888888888889
Class Counts: [3, 5, 22, 8, 13, 9, 7, 6, 10, 2, 5, 2, 2, 6, 15, 12, 9, 5, 6, 22, 10, 10, 6, 4, 2, 15, 2, 11, 14, 5, 4, 10, 7, 5, 16, 11, 10, 10, 10, 11, 16, 14, 7, 16, 8, 8, 14, 3, 13, 9]


Train Iter: 006/010  Loss: 2.6855: 100%|[32m██████████[0m| 1979/1979 [02:51<00:00, 11.53it/s]


Epoch 6, Loss: 2.3259029870197043, Accuracy: 0.35608369522305566


Valid Iter: 006/010  Loss: 0.6224: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.46it/s]


Validation Loss: 2.474023402531942, Accuracy: 0.3288888888888889
Class Counts: [3, 10, 16, 10, 18, 10, 8, 10, 6, 6, 7, 1, 3, 1, 10, 15, 12, 3, 6, 11, 13, 7, 4, 9, 11, 13, 8, 10, 5, 4, 4, 6, 8, 10, 10, 12, 14, 11, 7, 13, 15, 11, 6, 11, 11, 11, 11, 6, 16, 7]


Train Iter: 007/010  Loss: 2.0357: 100%|[32m██████████[0m| 1979/1979 [02:50<00:00, 11.61it/s]


Epoch 7, Loss: 2.1015909775057375, Accuracy: 0.4164232135807343


Valid Iter: 007/010  Loss: 0.4421: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.44it/s]


Validation Loss: 2.418163794146644, Accuracy: 0.34
Class Counts: [6, 9, 15, 8, 13, 12, 8, 8, 12, 5, 5, 1, 4, 5, 14, 6, 8, 11, 13, 10, 7, 6, 8, 5, 6, 13, 7, 9, 9, 7, 9, 6, 11, 7, 8, 10, 13, 4, 8, 16, 17, 10, 7, 18, 7, 7, 9, 7, 13, 13]


Train Iter: 008/010  Loss: 1.5950: 100%|[32m██████████[0m| 1979/1979 [02:51<00:00, 11.54it/s]


Epoch 8, Loss: 1.8552572194550097, Accuracy: 0.48508487958941965


Valid Iter: 008/010  Loss: 0.3371: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 26.15it/s]


Validation Loss: 2.4174457348717584, Accuracy: 0.34444444444444444
Class Counts: [7, 7, 13, 11, 19, 8, 3, 9, 4, 2, 7, 2, 4, 4, 9, 8, 8, 12, 15, 15, 13, 4, 4, 8, 9, 7, 6, 16, 6, 15, 3, 9, 10, 5, 12, 5, 12, 10, 6, 13, 16, 8, 10, 14, 7, 10, 18, 8, 12, 7]


Train Iter: 009/010  Loss: 1.8435: 100%|[32m██████████[0m| 1979/1979 [02:52<00:00, 11.49it/s]


Epoch 9, Loss: 1.5978395756044925, Accuracy: 0.5565574417686537


Valid Iter: 009/010  Loss: 0.4885: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.28it/s]


Validation Loss: 2.3632292921013303, Accuracy: 0.34
Class Counts: [8, 8, 12, 6, 21, 9, 5, 9, 7, 6, 4, 3, 4, 7, 9, 13, 8, 6, 20, 9, 7, 7, 11, 7, 6, 8, 3, 12, 9, 11, 3, 6, 4, 12, 9, 14, 13, 7, 7, 22, 17, 7, 2, 18, 11, 10, 10, 5, 12, 6]


Train Iter: 010/010  Loss: 0.9270: 100%|[32m██████████[0m| 1979/1979 [02:52<00:00, 11.47it/s]


Epoch 10, Loss: 1.340624063271004, Accuracy: 0.6348835373075404


Valid Iter: 010/010  Loss: 0.1901: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.98it/s]


Validation Loss: 2.3425543109575906, Accuracy: 0.3377777777777778
Class Counts: [8, 9, 16, 8, 21, 8, 2, 5, 7, 3, 8, 3, 5, 7, 10, 9, 14, 10, 8, 9, 8, 13, 15, 9, 3, 9, 8, 7, 7, 9, 4, 10, 8, 7, 10, 10, 12, 10, 11, 8, 13, 9, 10, 17, 13, 10, 9, 5, 11, 5]
Finished Training


Test Iter: 010/010  Loss: 0.1901: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.92it/s]

Test Loss: 0.19008785486221313, Accuracy: 0.37333333333333335
Precision: 0.37970499435205324, Recall: 0.37333333333333335, F1: 0.3620732739568653
Class Counts: [12, 10, 12, 8, 8, 8, 8, 7, 5, 6, 4, 7, 5, 11, 11, 11, 16, 9, 9, 13, 3, 13, 17, 10, 4, 11, 8, 9, 7, 10, 5, 9, 5, 10, 8, 7, 8, 8, 8, 6, 16, 11, 9, 9, 15, 12, 12, 9, 8, 3]
Finished Testing





In [4]:
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        
    def forward(self, x):
        batch_size, C, width, height = x.size()
        
        # 计算查询、键和值
        proj_query = self.query_conv(x).view(batch_size, -1, width * height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(batch_size, -1, width * height)
        energy = torch.bmm(proj_query, proj_key)
        
        # 计算注意力权重并应用到值上
        attention = F.softmax(energy, dim=-1)
        proj_value = self.value_conv(x).view(batch_size, -1, width * height)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        
        # 调整输出并加上残差连接
        out = out.view(batch_size, C, width, height)
        out = self.gamma * out + x
        return out

class CNNModel(nn.Module):
    def __init__(self, num_classes=50):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.attention1 = SelfAttention(256)

        self.fc1 = nn.Linear(256*28*28, 2048)
        self.fc2 = nn.Linear(2048, num_classes)
        self.activation = nn.ReLU()
        self.flatten = nn.Flatten()
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout25 = nn.Dropout(0.25)
        self.dropout50 = nn.Dropout(0.5)
        
    def forward(self, x):
        x = self.pool(self.activation(self.bn1(self.conv1(x))))
        x = self.pool(self.activation(self.bn2(self.conv2(x))))
        x = self.pool(self.activation(self.bn3(self.conv3(x))))
        x = self.attention1(x)
        x = self.dropout25(self.flatten(x))
        x = self.dropout50(self.activation(self.fc1(x)))
        x = self.fc2(x)
        return x


for lr in [1e-5]:
    model = CNNModel(num_classes=50).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    num_epochs = 10
    save_path = 'record/Q2_less_layers'
    rec_name = '6layers_2048'
    model_path = f'{save_path}/model_weight/model_weight_lr{lr}_{rec_name}.pth'
    best_val_loss = float('inf')
    best_model = None
    train_loss = []
    train_acc = []
    val_loss = []
    val_acc = []
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_acc = 0.0
        pbar = tqdm(data_loader_train, colour='green', total=len(data_loader_train))
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()*len(labels)
            running_acc += (outputs.argmax(dim=1) == labels).float().sum().item()
            pbar.set_description(f'Train Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
            pbar.update()
        running_loss_avg = running_loss / len(dataset_train)
        running_acc_avg = running_acc / len(dataset_train)
        train_loss.append(running_loss_avg)
        train_acc.append(running_acc_avg)
        print(f"Epoch {epoch+1}, Loss: {running_loss_avg}, Accuracy: {running_acc_avg}")

        model.eval()
        with torch.no_grad():
            class_counts = [0] * 50
            running_loss = 0.0
            running_acc = 0.0
            pbar = tqdm(data_loader_val, colour='red', total=len(data_loader_val))
            for images, labels in pbar:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                predictions = outputs.argmax(dim=1)
                loss = criterion(outputs, labels)
                running_loss += loss.item()*len(labels)
                running_acc += (predictions == labels).float().sum().item()
                pbar.set_description(f'Valid Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
                pbar.update()

                for pred_label in predictions:
                    class_counts[pred_label] += 1

            running_loss_avg = running_loss / len(dataset_val)
            running_acc_avg = running_acc / len(dataset_val)
            val_loss.append(running_loss_avg)
            val_acc.append(running_acc_avg)
            print(f"Validation Loss: {running_loss_avg}, Accuracy: {running_acc_avg}")
            print(f"Class Counts: {class_counts}")
            if running_loss_avg < best_val_loss:
                best_val_loss = running_loss
                best_model = model
                torch.save(model.state_dict(), model_path)
        print("======================================")
    print("Finished Training")


    model.eval()
    with torch.no_grad():
        class_counts = [0] * 50
        running_loss = 0.0
        running_acc = 0.0
        predictions = []
        true_labels = []
        pbar = tqdm(data_loader_test, colour='red', total=len(data_loader_test))
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            _, predicted = torch.max(outputs, 1)
            predictions.extend(predicted.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())
            loss = criterion(outputs, labels)

            running_loss += loss.item()*len(labels)
            running_acc += (outputs.argmax(dim=1) == labels).float().sum().item()
            pbar.set_description(f'Test Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
            pbar.update()

        for pred_label in predictions:
            class_counts[pred_label] += 1
        
        precision = precision_score(true_labels, predictions, average='weighted')
        recall = recall_score(true_labels, predictions, average='weighted')
        f1 = f1_score(true_labels, predictions, average='weighted')
        print(f"Test Loss: {running_loss/len(dataset_test)}, Accuracy: {running_acc/len(dataset_test)}")
        print(f"Precision: {precision}, Recall: {recall}, F1: {f1}")
        print(f"Class Counts: {class_counts}")
    print("Finished Testing")

    train_record = {"train_loss": train_loss, "train_acc": train_acc, "val_loss": val_loss, "val_acc": val_acc}
    test_record = {"true_labels": true_labels, "predictions": predictions, "precision": precision, "recall": recall, "f1": f1, "class_counts": class_counts}


    with open(f'{save_path}/train_record/train_record_lr{lr}_{rec_name}.pkl', 'wb') as f:
        pickle.dump(train_record, f)
    with open(f'{save_path}/test_record/test_record_lr{lr}_{rec_name}.pkl', 'wb') as f:
        pickle.dump(test_record, f)


Train Iter: 001/010  Loss: 3.0179: 100%|[32m██████████[0m| 1979/1979 [03:43<00:00,  8.84it/s]


Epoch 1, Loss: 3.3699071711680655, Accuracy: 0.12876431109356495


Valid Iter: 001/010  Loss: 1.2763: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 24.32it/s]


Validation Loss: 2.9545362122853596, Accuracy: 0.21555555555555556
Class Counts: [5, 10, 14, 7, 28, 12, 8, 0, 5, 2, 2, 0, 1, 1, 20, 3, 0, 6, 4, 11, 15, 11, 6, 12, 4, 1, 0, 14, 0, 10, 7, 0, 28, 8, 6, 16, 5, 12, 19, 3, 6, 27, 10, 18, 7, 9, 13, 6, 15, 23]


Train Iter: 002/010  Loss: 2.7731: 100%|[32m██████████[0m| 1979/1979 [03:42<00:00,  8.88it/s]


Epoch 2, Loss: 3.001345942343265, Accuracy: 0.19723647848401105


Valid Iter: 002/010  Loss: 1.4420: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 23.80it/s]


Validation Loss: 2.7752434004677666, Accuracy: 0.2288888888888889
Class Counts: [11, 4, 17, 33, 8, 8, 4, 4, 4, 1, 1, 0, 3, 7, 17, 2, 4, 6, 9, 8, 17, 9, 8, 22, 0, 8, 0, 9, 13, 8, 5, 13, 2, 10, 14, 8, 22, 5, 12, 17, 10, 12, 0, 19, 6, 5, 18, 8, 15, 4]


Train Iter: 003/010  Loss: 2.7871: 100%|[32m██████████[0m| 1979/1979 [03:42<00:00,  8.89it/s]


Epoch 3, Loss: 2.7563749495128658, Accuracy: 0.25334386103434664


Valid Iter: 003/010  Loss: 1.0524: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 22.86it/s]


Validation Loss: 2.58701041592492, Accuracy: 0.30666666666666664
Class Counts: [2, 8, 14, 9, 22, 10, 6, 9, 10, 0, 2, 1, 2, 10, 5, 6, 8, 3, 3, 15, 10, 7, 14, 12, 1, 13, 3, 14, 9, 5, 13, 13, 8, 22, 11, 7, 6, 14, 7, 21, 12, 9, 6, 13, 7, 12, 12, 5, 9, 10]


Train Iter: 004/010  Loss: 2.3469: 100%|[32m██████████[0m| 1979/1979 [03:41<00:00,  8.92it/s]


Epoch 4, Loss: 2.510671392452844, Accuracy: 0.30959336754836164


Valid Iter: 004/010  Loss: 0.5544: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 24.62it/s]


Validation Loss: 2.4979684766133627, Accuracy: 0.3111111111111111
Class Counts: [4, 8, 12, 4, 24, 8, 8, 5, 9, 1, 5, 0, 2, 9, 8, 7, 4, 2, 10, 20, 8, 6, 9, 9, 8, 9, 2, 15, 9, 8, 7, 3, 9, 6, 9, 15, 11, 6, 7, 18, 21, 9, 19, 14, 7, 12, 15, 5, 11, 13]


Train Iter: 005/010  Loss: 1.9801: 100%|[32m██████████[0m| 1979/1979 [03:44<00:00,  8.83it/s]


Epoch 5, Loss: 2.245434084599356, Accuracy: 0.37831819976312675


Valid Iter: 005/010  Loss: 0.5439: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 21.51it/s]


Validation Loss: 2.4034217691421507, Accuracy: 0.31555555555555553
Class Counts: [4, 7, 11, 7, 20, 8, 8, 3, 10, 2, 7, 0, 3, 12, 14, 14, 4, 8, 15, 11, 13, 5, 8, 12, 3, 9, 4, 10, 8, 13, 6, 4, 9, 7, 10, 10, 8, 8, 6, 14, 12, 11, 9, 16, 11, 11, 18, 4, 14, 9]


Train Iter: 006/010  Loss: 2.0426: 100%|[32m██████████[0m| 1979/1979 [03:43<00:00,  8.84it/s]


Epoch 6, Loss: 1.9587714490201509, Accuracy: 0.4543229372285827


Valid Iter: 006/010  Loss: 0.5232: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 24.42it/s]


Validation Loss: 2.3730297552214727, Accuracy: 0.3333333333333333
Class Counts: [5, 10, 17, 12, 20, 9, 6, 6, 12, 3, 9, 3, 5, 7, 11, 17, 6, 10, 9, 7, 9, 8, 4, 7, 3, 7, 5, 9, 7, 9, 17, 10, 10, 7, 7, 12, 11, 11, 6, 12, 8, 9, 8, 11, 10, 12, 9, 3, 11, 14]


Train Iter: 007/010  Loss: 1.8416: 100%|[32m██████████[0m| 1979/1979 [03:44<00:00,  8.83it/s]


Epoch 7, Loss: 1.641339238033528, Accuracy: 0.5450296091590999


Valid Iter: 007/010  Loss: 0.3787: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 23.91it/s]


Validation Loss: 2.3321826790438758, Accuracy: 0.34444444444444444
Class Counts: [8, 5, 13, 11, 18, 11, 6, 10, 10, 2, 7, 5, 6, 11, 6, 7, 8, 6, 13, 6, 9, 7, 15, 9, 13, 5, 4, 14, 9, 15, 10, 4, 9, 8, 6, 10, 16, 8, 6, 9, 9, 6, 8, 11, 14, 14, 11, 5, 9, 8]


Train Iter: 008/010  Loss: 1.2481: 100%|[32m██████████[0m| 1979/1979 [03:41<00:00,  8.93it/s]


Epoch 8, Loss: 1.3115808126778645, Accuracy: 0.6438689301223846


Valid Iter: 008/010  Loss: 0.2073: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 24.68it/s]


Validation Loss: 2.382743229866028, Accuracy: 0.3422222222222222
Class Counts: [5, 9, 9, 10, 19, 12, 6, 6, 5, 5, 9, 4, 5, 8, 7, 9, 12, 11, 7, 6, 12, 14, 14, 9, 3, 17, 5, 11, 7, 6, 5, 9, 16, 9, 7, 10, 6, 10, 6, 11, 10, 9, 6, 12, 6, 10, 12, 10, 13, 11]


Train Iter: 009/010  Loss: 0.7943: 100%|[32m██████████[0m| 1979/1979 [03:40<00:00,  8.97it/s]


Epoch 9, Loss: 1.0095623683722368, Accuracy: 0.7378760363205685


Valid Iter: 009/010  Loss: 0.1818: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.00it/s]


Validation Loss: 2.3753189714749654, Accuracy: 0.3288888888888889
Class Counts: [9, 10, 13, 11, 17, 6, 5, 5, 7, 7, 6, 3, 4, 8, 9, 18, 8, 15, 9, 23, 13, 12, 7, 10, 11, 9, 5, 12, 7, 7, 4, 3, 9, 9, 7, 7, 7, 9, 8, 5, 11, 11, 5, 11, 9, 9, 8, 7, 12, 13]


Train Iter: 010/010  Loss: 1.1189: 100%|[32m██████████[0m| 1979/1979 [03:39<00:00,  9.02it/s]


Epoch 10, Loss: 0.7323346991168149, Accuracy: 0.8213028030003948


Valid Iter: 010/010  Loss: 0.0581: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 24.80it/s]


Validation Loss: 2.386636883947584, Accuracy: 0.3466666666666667
Class Counts: [4, 12, 12, 11, 18, 9, 6, 4, 6, 10, 6, 4, 4, 8, 5, 11, 11, 22, 11, 7, 14, 12, 7, 7, 8, 10, 5, 11, 7, 10, 5, 5, 7, 5, 8, 9, 8, 11, 5, 8, 11, 13, 9, 8, 5, 9, 12, 10, 11, 19]
Finished Training


Test Iter: 010/010  Loss: 0.0581: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 24.77it/s]

Test Loss: 0.058064818382263184, Accuracy: 0.3511111111111111
Precision: 0.36408938773644656, Recall: 0.3511111111111111, F1: 0.34448776757808924
Class Counts: [8, 8, 13, 9, 12, 7, 7, 5, 5, 10, 3, 12, 3, 9, 6, 9, 15, 17, 12, 7, 7, 6, 8, 11, 8, 11, 4, 8, 7, 11, 9, 12, 3, 8, 7, 12, 8, 11, 10, 7, 8, 12, 12, 9, 7, 11, 18, 6, 8, 14]
Finished Testing





In [6]:
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        
    def forward(self, x):
        batch_size, C, width, height = x.size()
        
        # 计算查询、键和值
        proj_query = self.query_conv(x).view(batch_size, -1, width * height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(batch_size, -1, width * height)
        energy = torch.bmm(proj_query, proj_key)
        
        # 计算注意力权重并应用到值上
        attention = F.softmax(energy, dim=-1)
        proj_value = self.value_conv(x).view(batch_size, -1, width * height)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        
        # 调整输出并加上残差连接
        out = out.view(batch_size, C, width, height)
        out = self.gamma * out + x
        return out

class CNNModel(nn.Module):
    def __init__(self, num_classes=50):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.attention1 = SelfAttention(256)

        self.fc1 = nn.Linear(256*28*28, 1024)
        self.fc2 = nn.Linear(1024, num_classes)
        self.activation = nn.ReLU()
        self.flatten = nn.Flatten()
        self.pool = nn.MaxPool2d(2, 2)
        # self.dropout25 = nn.Dropout(0.25)
        self.dropout50 = nn.Dropout(0.5)
        
    def forward(self, x):
        x = self.pool(self.activation(self.bn1(self.conv1(x))))
        x = self.pool(self.activation(self.bn2(self.conv2(x))))
        x = self.pool(self.activation(self.bn3(self.conv3(x))))
        x = self.attention1(x)
        x = self.dropout50(self.flatten(x))
        x = self.dropout50(self.activation(self.fc1(x)))
        x = self.fc2(x)
        return x


for lr in [1e-5]:
    model = CNNModel(num_classes=50).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    num_epochs = 10
    save_path = 'record/Q2_less_layers'
    rec_name = '6layers_dropout50'
    model_path = f'{save_path}/model_weight/model_weight_lr{lr}_{rec_name}.pth'
    best_val_loss = float('inf')
    best_model = None
    train_loss = []
    train_acc = []
    val_loss = []
    val_acc = []
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_acc = 0.0
        pbar = tqdm(data_loader_train, colour='green', total=len(data_loader_train))
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()*len(labels)
            running_acc += (outputs.argmax(dim=1) == labels).float().sum().item()
            pbar.set_description(f'Train Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
            pbar.update()
        running_loss_avg = running_loss / len(dataset_train)
        running_acc_avg = running_acc / len(dataset_train)
        train_loss.append(running_loss_avg)
        train_acc.append(running_acc_avg)
        print(f"Epoch {epoch+1}, Loss: {running_loss_avg}, Accuracy: {running_acc_avg}")

        model.eval()
        with torch.no_grad():
            class_counts = [0] * 50
            running_loss = 0.0
            running_acc = 0.0
            pbar = tqdm(data_loader_val, colour='red', total=len(data_loader_val))
            for images, labels in pbar:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                predictions = outputs.argmax(dim=1)
                loss = criterion(outputs, labels)
                running_loss += loss.item()*len(labels)
                running_acc += (predictions == labels).float().sum().item()
                pbar.set_description(f'Valid Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
                pbar.update()

                for pred_label in predictions:
                    class_counts[pred_label] += 1

            running_loss_avg = running_loss / len(dataset_val)
            running_acc_avg = running_acc / len(dataset_val)
            val_loss.append(running_loss_avg)
            val_acc.append(running_acc_avg)
            print(f"Validation Loss: {running_loss_avg}, Accuracy: {running_acc_avg}")
            print(f"Class Counts: {class_counts}")
            if running_loss_avg < best_val_loss:
                best_val_loss = running_loss
                best_model = model
                torch.save(model.state_dict(), model_path)
        print("======================================")
    print("Finished Training")


    model.eval()
    with torch.no_grad():
        class_counts = [0] * 50
        running_loss = 0.0
        running_acc = 0.0
        predictions = []
        true_labels = []
        pbar = tqdm(data_loader_test, colour='red', total=len(data_loader_test))
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            _, predicted = torch.max(outputs, 1)
            predictions.extend(predicted.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())
            loss = criterion(outputs, labels)

            running_loss += loss.item()*len(labels)
            running_acc += (outputs.argmax(dim=1) == labels).float().sum().item()
            pbar.set_description(f'Test Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
            pbar.update()

        for pred_label in predictions:
            class_counts[pred_label] += 1
        
        precision = precision_score(true_labels, predictions, average='weighted')
        recall = recall_score(true_labels, predictions, average='weighted')
        f1 = f1_score(true_labels, predictions, average='weighted')
        print(f"Test Loss: {running_loss/len(dataset_test)}, Accuracy: {running_acc/len(dataset_test)}")
        print(f"Precision: {precision}, Recall: {recall}, F1: {f1}")
        print(f"Class Counts: {class_counts}")
    print("Finished Testing")

    train_record = {"train_loss": train_loss, "train_acc": train_acc, "val_loss": val_loss, "val_acc": val_acc}
    test_record = {"true_labels": true_labels, "predictions": predictions, "precision": precision, "recall": recall, "f1": f1, "class_counts": class_counts}


    with open(f'{save_path}/train_record/train_record_lr{lr}_{rec_name}.pkl', 'wb') as f:
        pickle.dump(train_record, f)
    with open(f'{save_path}/test_record/test_record_lr{lr}_{rec_name}.pkl', 'wb') as f:
        pickle.dump(test_record, f)


Train Iter: 001/010  Loss: 3.3324: 100%|[32m██████████[0m| 1979/1979 [02:51<00:00, 11.55it/s]


Epoch 1, Loss: 3.472888354923807, Accuracy: 0.11063560994867745


Valid Iter: 001/010  Loss: 2.1215: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.24it/s]


Validation Loss: 3.070883650249905, Accuracy: 0.19333333333333333
Class Counts: [10, 1, 8, 43, 12, 4, 3, 1, 9, 3, 1, 0, 0, 1, 21, 5, 7, 4, 6, 1, 2, 10, 2, 17, 17, 5, 1, 34, 3, 1, 1, 2, 12, 9, 14, 7, 4, 17, 2, 16, 5, 9, 37, 8, 10, 15, 11, 13, 11, 15]


Train Iter: 002/010  Loss: 3.0650: 100%|[32m██████████[0m| 1979/1979 [02:52<00:00, 11.48it/s]


Epoch 2, Loss: 3.128225928407105, Accuracy: 0.17285432293722858


Valid Iter: 002/010  Loss: 1.3241: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.42it/s]


Validation Loss: 2.86743846628401, Accuracy: 0.2288888888888889
Class Counts: [6, 12, 14, 25, 17, 7, 4, 6, 5, 3, 2, 0, 3, 10, 9, 3, 3, 2, 8, 13, 8, 9, 13, 5, 4, 10, 1, 16, 4, 7, 3, 0, 16, 32, 6, 8, 8, 7, 4, 9, 20, 15, 9, 21, 4, 10, 10, 10, 10, 19]


Train Iter: 003/010  Loss: 2.7458: 100%|[32m██████████[0m| 1979/1979 [02:51<00:00, 11.51it/s]


Epoch 3, Loss: 2.9113066682427755, Accuracy: 0.21888669561784446


Valid Iter: 003/010  Loss: 1.3594: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.43it/s]


Validation Loss: 2.7119095542695786, Accuracy: 0.2733333333333333
Class Counts: [6, 11, 15, 16, 29, 10, 8, 4, 7, 3, 1, 0, 3, 3, 6, 9, 10, 1, 10, 10, 9, 13, 8, 7, 4, 14, 5, 8, 6, 12, 4, 7, 10, 11, 13, 9, 12, 7, 7, 14, 16, 8, 15, 13, 8, 11, 7, 6, 11, 13]


Train Iter: 004/010  Loss: 2.4583: 100%|[32m██████████[0m| 1979/1979 [02:51<00:00, 11.56it/s]


Epoch 4, Loss: 2.6917230732319526, Accuracy: 0.2679510461902882


Valid Iter: 004/010  Loss: 0.9209: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 26.25it/s]


Validation Loss: 2.63084832350413, Accuracy: 0.29555555555555557
Class Counts: [6, 1, 16, 10, 28, 4, 5, 9, 4, 3, 3, 2, 4, 7, 13, 6, 5, 3, 9, 15, 10, 3, 12, 7, 16, 7, 4, 7, 6, 19, 5, 2, 13, 14, 9, 5, 13, 11, 7, 15, 23, 8, 6, 18, 13, 9, 11, 4, 12, 8]


Train Iter: 005/010  Loss: 2.7734: 100%|[32m██████████[0m| 1979/1979 [02:51<00:00, 11.53it/s]


Epoch 5, Loss: 2.4702518017697703, Accuracy: 0.32443742597710223


Valid Iter: 005/010  Loss: 0.9961: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 24.80it/s]


Validation Loss: 2.529768679406908, Accuracy: 0.2911111111111111
Class Counts: [8, 1, 12, 10, 22, 8, 7, 15, 10, 0, 1, 0, 2, 5, 18, 5, 5, 8, 16, 14, 10, 3, 8, 11, 18, 11, 5, 9, 10, 8, 5, 6, 7, 11, 7, 8, 13, 8, 13, 12, 10, 14, 8, 15, 11, 10, 7, 5, 12, 8]


Train Iter: 006/010  Loss: 2.3791: 100%|[32m██████████[0m| 1979/1979 [02:52<00:00, 11.49it/s]


Epoch 6, Loss: 2.241417315434525, Accuracy: 0.38171338333991317


Valid Iter: 006/010  Loss: 0.3327: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.58it/s]


Validation Loss: 2.4671743529372745, Accuracy: 0.3022222222222222
Class Counts: [13, 4, 11, 9, 17, 12, 8, 11, 7, 4, 1, 2, 4, 4, 8, 9, 6, 7, 6, 16, 14, 5, 7, 15, 7, 9, 1, 13, 5, 19, 1, 7, 9, 13, 14, 12, 18, 14, 8, 7, 13, 7, 5, 16, 9, 9, 5, 4, 14, 11]


Train Iter: 007/010  Loss: 2.2504: 100%|[32m██████████[0m| 1979/1979 [02:52<00:00, 11.49it/s]


Epoch 7, Loss: 1.9883548521628458, Accuracy: 0.4496012633241216


Valid Iter: 007/010  Loss: 0.4599: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.93it/s]


Validation Loss: 2.3747906226581996, Accuracy: 0.3466666666666667
Class Counts: [6, 10, 10, 10, 13, 9, 5, 16, 12, 5, 6, 3, 5, 6, 10, 7, 7, 8, 12, 18, 9, 5, 12, 6, 13, 14, 3, 10, 8, 8, 3, 6, 6, 4, 9, 10, 12, 9, 9, 16, 10, 4, 11, 16, 12, 8, 11, 5, 15, 8]


Train Iter: 008/010  Loss: 1.5884: 100%|[32m██████████[0m| 1979/1979 [02:51<00:00, 11.53it/s]


Epoch 8, Loss: 1.7225304572515416, Accuracy: 0.5218476115278327


Valid Iter: 008/010  Loss: 0.3200: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.38it/s]


Validation Loss: 2.3756455278396604, Accuracy: 0.34444444444444444
Class Counts: [6, 9, 13, 16, 15, 10, 6, 6, 11, 2, 7, 4, 5, 6, 17, 9, 9, 9, 9, 16, 10, 2, 8, 7, 14, 15, 1, 12, 5, 8, 4, 2, 16, 5, 11, 7, 9, 5, 6, 20, 12, 3, 13, 11, 8, 10, 14, 4, 12, 11]


Train Iter: 009/010  Loss: 1.1538: 100%|[32m██████████[0m| 1979/1979 [02:52<00:00, 11.46it/s]


Epoch 9, Loss: 1.4619719098012438, Accuracy: 0.60001579155152


Valid Iter: 009/010  Loss: 0.4756: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 24.36it/s]


Validation Loss: 2.3619219273991057, Accuracy: 0.3288888888888889
Class Counts: [7, 11, 13, 18, 20, 12, 4, 8, 15, 3, 4, 4, 6, 7, 9, 7, 8, 7, 9, 14, 11, 14, 5, 8, 4, 8, 10, 11, 5, 7, 6, 13, 7, 6, 8, 7, 4, 11, 6, 14, 14, 5, 11, 15, 8, 8, 11, 6, 14, 7]


Train Iter: 010/010  Loss: 1.2606: 100%|[32m██████████[0m| 1979/1979 [02:51<00:00, 11.54it/s]


Epoch 10, Loss: 1.2142170782123118, Accuracy: 0.6728464271614686


Valid Iter: 010/010  Loss: 0.2201: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.49it/s]


Validation Loss: 2.34850836303499, Accuracy: 0.3333333333333333
Class Counts: [9, 11, 10, 16, 15, 11, 6, 7, 12, 5, 4, 4, 5, 13, 5, 4, 6, 11, 9, 11, 12, 5, 6, 9, 10, 13, 4, 13, 5, 9, 3, 6, 13, 10, 8, 9, 11, 7, 6, 13, 13, 5, 6, 18, 7, 14, 11, 7, 12, 11]
Finished Training


Test Iter: 010/010  Loss: 0.2201: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 23.73it/s]


Test Loss: 0.22008007764816284, Accuracy: 0.37333333333333335
Precision: 0.3627560495060495, Recall: 0.37333333333333335, F1: 0.3582361576444225
Class Counts: [12, 9, 11, 11, 6, 11, 8, 7, 11, 3, 5, 2, 5, 15, 6, 9, 6, 4, 12, 9, 9, 10, 13, 13, 3, 9, 7, 9, 9, 11, 5, 4, 8, 16, 8, 11, 8, 8, 9, 8, 10, 11, 13, 10, 10, 15, 14, 8, 10, 9]
Finished Testing


In [8]:
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        
    def forward(self, x):
        batch_size, C, width, height = x.size()
        
        # 计算查询、键和值
        proj_query = self.query_conv(x).view(batch_size, -1, width * height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(batch_size, -1, width * height)
        energy = torch.bmm(proj_query, proj_key)
        
        # 计算注意力权重并应用到值上
        attention = F.softmax(energy, dim=-1)
        proj_value = self.value_conv(x).view(batch_size, -1, width * height)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        
        # 调整输出并加上残差连接
        out = out.view(batch_size, C, width, height)
        out = self.gamma * out + x
        return out

class CNNModel(nn.Module):
    def __init__(self, num_classes=50):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.attention1 = SelfAttention(256)

        self.fc1 = nn.Linear(256*28*28, 2048)
        self.fc2 = nn.Linear(2048, num_classes)
        self.activation = nn.ReLU()
        self.flatten = nn.Flatten()
        self.pool = nn.MaxPool2d(2, 2)
        # self.dropout25 = nn.Dropout(0.25)
        self.dropout50 = nn.Dropout(0.5)
        
    def forward(self, x):
        x = self.pool(self.activation(self.bn1(self.conv1(x))))
        x = self.pool(self.activation(self.bn2(self.conv2(x))))
        x = self.pool(self.activation(self.bn3(self.conv3(x))))
        x = self.attention1(x)
        x = self.dropout50(self.flatten(x))
        x = self.dropout50(self.activation(self.fc1(x)))
        x = self.fc2(x)
        return x


for lr in [1e-5]:
    model = CNNModel(num_classes=50).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    num_epochs = 10
    save_path = 'record/Q2_less_layers'
    rec_name = '6layers_2048_dropout50_weightdecay'
    model_path = f'{save_path}/model_weight/model_weight_lr{lr}_{rec_name}.pth'
    best_val_loss = float('inf')
    best_model = None
    train_loss = []
    train_acc = []
    val_loss = []
    val_acc = []
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_acc = 0.0
        pbar = tqdm(data_loader_train, colour='green', total=len(data_loader_train))
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()*len(labels)
            running_acc += (outputs.argmax(dim=1) == labels).float().sum().item()
            pbar.set_description(f'Train Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
            pbar.update()
        running_loss_avg = running_loss / len(dataset_train)
        running_acc_avg = running_acc / len(dataset_train)
        train_loss.append(running_loss_avg)
        train_acc.append(running_acc_avg)
        print(f"Epoch {epoch+1}, Loss: {running_loss_avg}, Accuracy: {running_acc_avg}")

        model.eval()
        with torch.no_grad():
            class_counts = [0] * 50
            running_loss = 0.0
            running_acc = 0.0
            pbar = tqdm(data_loader_val, colour='red', total=len(data_loader_val))
            for images, labels in pbar:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                predictions = outputs.argmax(dim=1)
                loss = criterion(outputs, labels)
                running_loss += loss.item()*len(labels)
                running_acc += (predictions == labels).float().sum().item()
                pbar.set_description(f'Valid Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
                pbar.update()

                for pred_label in predictions:
                    class_counts[pred_label] += 1

            running_loss_avg = running_loss / len(dataset_val)
            running_acc_avg = running_acc / len(dataset_val)
            val_loss.append(running_loss_avg)
            val_acc.append(running_acc_avg)
            print(f"Validation Loss: {running_loss_avg}, Accuracy: {running_acc_avg}")
            print(f"Class Counts: {class_counts}")
            if running_loss_avg < best_val_loss:
                best_val_loss = running_loss
                best_model = model
                torch.save(model.state_dict(), model_path)
        print("======================================")
    print("Finished Training")


    model.eval()
    with torch.no_grad():
        class_counts = [0] * 50
        running_loss = 0.0
        running_acc = 0.0
        predictions = []
        true_labels = []
        pbar = tqdm(data_loader_test, colour='red', total=len(data_loader_test))
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            _, predicted = torch.max(outputs, 1)
            predictions.extend(predicted.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())
            loss = criterion(outputs, labels)

            running_loss += loss.item()*len(labels)
            running_acc += (outputs.argmax(dim=1) == labels).float().sum().item()
            pbar.set_description(f'Test Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
            pbar.update()

        for pred_label in predictions:
            class_counts[pred_label] += 1
        
        precision = precision_score(true_labels, predictions, average='weighted')
        recall = recall_score(true_labels, predictions, average='weighted')
        f1 = f1_score(true_labels, predictions, average='weighted')
        print(f"Test Loss: {running_loss/len(dataset_test)}, Accuracy: {running_acc/len(dataset_test)}")
        print(f"Precision: {precision}, Recall: {recall}, F1: {f1}")
        print(f"Class Counts: {class_counts}")
    print("Finished Testing")

    train_record = {"train_loss": train_loss, "train_acc": train_acc, "val_loss": val_loss, "val_acc": val_acc}
    test_record = {"true_labels": true_labels, "predictions": predictions, "precision": precision, "recall": recall, "f1": f1, "class_counts": class_counts}


    with open(f'{save_path}/train_record/train_record_lr{lr}_{rec_name}.pkl', 'wb') as f:
        pickle.dump(train_record, f)
    with open(f'{save_path}/test_record/test_record_lr{lr}_{rec_name}.pkl', 'wb') as f:
        pickle.dump(test_record, f)


Train Iter: 001/010  Loss: 3.3190: 100%|[32m██████████[0m| 1979/1979 [03:49<00:00,  8.64it/s]


Epoch 1, Loss: 3.3806793784455165, Accuracy: 0.12958547177260166


Valid Iter: 001/010  Loss: 2.2356: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.04it/s]


Validation Loss: 2.991858303281996, Accuracy: 0.18222222222222223
Class Counts: [22, 35, 5, 0, 28, 11, 3, 3, 0, 5, 0, 0, 1, 0, 6, 19, 0, 1, 4, 15, 17, 9, 26, 6, 10, 2, 4, 21, 4, 8, 2, 9, 8, 16, 12, 6, 4, 9, 10, 8, 8, 5, 18, 2, 13, 17, 19, 6, 10, 3]


Train Iter: 002/010  Loss: 2.9143: 100%|[32m██████████[0m| 1979/1979 [03:49<00:00,  8.63it/s]


Epoch 2, Loss: 2.9955717641150486, Accuracy: 0.20048953809711803


Valid Iter: 002/010  Loss: 1.4628: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.25it/s]


Validation Loss: 2.7451139227549235, Accuracy: 0.23555555555555555
Class Counts: [4, 11, 17, 2, 30, 16, 5, 2, 1, 2, 5, 0, 1, 6, 11, 17, 13, 1, 7, 10, 16, 12, 16, 7, 7, 12, 1, 8, 11, 3, 2, 3, 9, 9, 14, 1, 8, 17, 8, 19, 13, 9, 4, 18, 14, 13, 12, 0, 13, 10]


Train Iter: 003/010  Loss: 2.7760: 100%|[32m██████████[0m| 1979/1979 [03:50<00:00,  8.60it/s]


Epoch 3, Loss: 2.7439792676052934, Accuracy: 0.2572917489143308


Valid Iter: 003/010  Loss: 0.7716: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 24.85it/s]


Validation Loss: 2.6423211240768434, Accuracy: 0.26666666666666666
Class Counts: [6, 5, 9, 21, 6, 8, 6, 11, 3, 0, 5, 1, 5, 6, 9, 22, 6, 7, 8, 17, 12, 5, 7, 5, 7, 13, 10, 7, 12, 6, 1, 14, 11, 7, 14, 8, 3, 4, 7, 24, 18, 2, 21, 9, 9, 13, 15, 3, 10, 12]


Train Iter: 004/010  Loss: 2.7396: 100%|[32m██████████[0m| 1979/1979 [03:49<00:00,  8.62it/s]


Epoch 4, Loss: 2.4773737301621286, Accuracy: 0.3196999605211212


Valid Iter: 004/010  Loss: 0.7260: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 24.69it/s]


Validation Loss: 2.4843821557362875, Accuracy: 0.3022222222222222
Class Counts: [4, 10, 22, 7, 20, 15, 7, 8, 10, 3, 4, 5, 5, 8, 7, 16, 10, 8, 5, 12, 7, 4, 4, 5, 2, 4, 10, 10, 8, 10, 6, 6, 9, 7, 12, 16, 9, 8, 8, 10, 16, 13, 6, 9, 10, 9, 13, 7, 14, 12]


Train Iter: 005/010  Loss: 1.8839: 100%|[32m██████████[0m| 1979/1979 [03:49<00:00,  8.64it/s]


Epoch 5, Loss: 2.192806824323833, Accuracy: 0.3941097512830636


Valid Iter: 005/010  Loss: 0.8791: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 19.66it/s]


Validation Loss: 2.4334191274642945, Accuracy: 0.29555555555555557
Class Counts: [4, 8, 13, 10, 17, 12, 6, 10, 5, 4, 5, 1, 6, 2, 7, 9, 10, 12, 7, 14, 20, 5, 12, 8, 3, 9, 5, 11, 8, 11, 0, 10, 8, 7, 11, 6, 11, 10, 15, 21, 13, 2, 6, 19, 8, 7, 11, 5, 12, 14]


Train Iter: 006/010  Loss: 1.8443: 100%|[32m██████████[0m| 1979/1979 [03:50<00:00,  8.59it/s]


Epoch 6, Loss: 1.8774829417107088, Accuracy: 0.47660481642321356


Valid Iter: 006/010  Loss: 0.3534: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.19it/s]


Validation Loss: 2.3738451608022055, Accuracy: 0.32222222222222224
Class Counts: [4, 10, 10, 11, 12, 14, 5, 5, 5, 3, 7, 4, 4, 7, 18, 8, 8, 9, 8, 8, 7, 4, 18, 5, 6, 8, 8, 18, 14, 5, 9, 6, 8, 11, 7, 6, 10, 6, 11, 8, 12, 15, 12, 20, 6, 8, 10, 10, 12, 10]


Train Iter: 007/010  Loss: 1.8957: 100%|[32m██████████[0m| 1979/1979 [03:48<00:00,  8.66it/s]


Epoch 7, Loss: 1.5414551266633707, Accuracy: 0.5755388866956178


Valid Iter: 007/010  Loss: 0.2696: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 24.02it/s]


Validation Loss: 2.3185087766912247, Accuracy: 0.33555555555555555
Class Counts: [4, 12, 13, 8, 14, 9, 7, 10, 10, 3, 6, 0, 7, 9, 15, 6, 9, 18, 7, 8, 18, 4, 6, 4, 9, 6, 10, 15, 10, 6, 4, 10, 13, 7, 5, 11, 11, 8, 7, 19, 11, 4, 6, 10, 6, 8, 16, 7, 11, 13]


Train Iter: 008/010  Loss: 1.3435: 100%|[32m██████████[0m| 1979/1979 [03:49<00:00,  8.62it/s]


Epoch 8, Loss: 1.2137994844942661, Accuracy: 0.6758152388472167


Valid Iter: 008/010  Loss: 0.3469: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.42it/s]


Validation Loss: 2.336239240169525, Accuracy: 0.34
Class Counts: [5, 10, 10, 9, 10, 5, 6, 11, 8, 5, 3, 5, 5, 7, 18, 11, 9, 11, 17, 7, 22, 3, 8, 6, 6, 10, 8, 11, 6, 7, 1, 9, 16, 11, 4, 12, 14, 8, 7, 11, 11, 7, 6, 16, 8, 8, 6, 12, 9, 15]


Train Iter: 009/010  Loss: 0.6700: 100%|[32m██████████[0m| 1979/1979 [03:47<00:00,  8.70it/s]


Epoch 9, Loss: 0.9137363653695023, Accuracy: 0.7641058033951835


Valid Iter: 009/010  Loss: 0.3857: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.24it/s]


Validation Loss: 2.3180692270067005, Accuracy: 0.3688888888888889
Class Counts: [6, 6, 11, 10, 9, 7, 4, 7, 6, 10, 8, 3, 6, 3, 18, 12, 7, 9, 8, 13, 11, 6, 10, 10, 18, 13, 3, 11, 6, 7, 4, 7, 13, 8, 6, 8, 11, 6, 10, 9, 11, 10, 20, 12, 8, 7, 10, 6, 14, 12]


Train Iter: 010/010  Loss: 0.6594: 100%|[32m██████████[0m| 1979/1979 [03:48<00:00,  8.65it/s]


Epoch 10, Loss: 0.6632199572396759, Accuracy: 0.8438531385708646


Valid Iter: 010/010  Loss: 0.5004: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 24.19it/s]


Validation Loss: 2.346022649606069, Accuracy: 0.3711111111111111
Class Counts: [11, 7, 10, 10, 17, 7, 3, 8, 9, 10, 4, 5, 4, 8, 9, 12, 5, 12, 8, 14, 14, 10, 16, 9, 7, 13, 6, 12, 8, 6, 3, 7, 9, 9, 7, 8, 6, 12, 12, 10, 11, 7, 7, 16, 12, 6, 6, 6, 12, 10]
Finished Training


Test Iter: 010/010  Loss: 0.5004: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 23.50it/s]

Test Loss: 0.5004361271858215, Accuracy: 0.3622222222222222
Precision: 0.36843709231944527, Recall: 0.3622222222222222, F1: 0.3480189904545134
Class Counts: [9, 9, 9, 1, 13, 7, 7, 14, 12, 9, 4, 5, 6, 8, 10, 9, 12, 9, 11, 17, 8, 6, 12, 11, 4, 14, 2, 6, 5, 9, 7, 8, 4, 13, 10, 10, 7, 11, 10, 8, 10, 11, 11, 12, 10, 12, 12, 9, 11, 6]
Finished Testing





In [9]:
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        
    def forward(self, x):
        batch_size, C, width, height = x.size()
        
        # 计算查询、键和值
        proj_query = self.query_conv(x).view(batch_size, -1, width * height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(batch_size, -1, width * height)
        energy = torch.bmm(proj_query, proj_key)
        
        # 计算注意力权重并应用到值上
        attention = F.softmax(energy, dim=-1)
        proj_value = self.value_conv(x).view(batch_size, -1, width * height)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        
        # 调整输出并加上残差连接
        out = out.view(batch_size, C, width, height)
        out = self.gamma * out + x
        return out

class CNNModel(nn.Module):
    def __init__(self, num_classes=50):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.attention1 = SelfAttention(256)

        self.fc1 = nn.Linear(256*28*28, 2048)
        self.fc2 = nn.Linear(2048, num_classes)
        self.activation = nn.ReLU()
        self.flatten = nn.Flatten()
        self.pool = nn.MaxPool2d(2, 2)
        # self.dropout25 = nn.Dropout(0.25)
        self.dropout50 = nn.Dropout(0.5)
        
    def forward(self, x):
        x = self.pool(self.activation(self.bn1(self.conv1(x))))
        x = self.pool(self.activation(self.bn2(self.conv2(x))))
        x = self.pool(self.activation(self.bn3(self.conv3(x))))
        x = self.attention1(x)
        x = self.dropout50(self.flatten(x))
        x = self.dropout50(self.activation(self.fc1(x)))
        x = self.fc2(x)
        return x


for lr in [1e-5]:
    model = CNNModel(num_classes=50).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-2)
    num_epochs = 10
    save_path = 'record/Q2_less_layers'
    rec_name = '6layers_2048_dropout50_adamW'
    model_path = f'{save_path}/model_weight/model_weight_lr{lr}_{rec_name}.pth'
    best_val_loss = float('inf')
    best_model = None
    train_loss = []
    train_acc = []
    val_loss = []
    val_acc = []
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_acc = 0.0
        pbar = tqdm(data_loader_train, colour='green', total=len(data_loader_train))
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()*len(labels)
            running_acc += (outputs.argmax(dim=1) == labels).float().sum().item()
            pbar.set_description(f'Train Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
            pbar.update()
        running_loss_avg = running_loss / len(dataset_train)
        running_acc_avg = running_acc / len(dataset_train)
        train_loss.append(running_loss_avg)
        train_acc.append(running_acc_avg)
        print(f"Epoch {epoch+1}, Loss: {running_loss_avg}, Accuracy: {running_acc_avg}")

        model.eval()
        with torch.no_grad():
            class_counts = [0] * 50
            running_loss = 0.0
            running_acc = 0.0
            pbar = tqdm(data_loader_val, colour='red', total=len(data_loader_val))
            for images, labels in pbar:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                predictions = outputs.argmax(dim=1)
                loss = criterion(outputs, labels)
                running_loss += loss.item()*len(labels)
                running_acc += (predictions == labels).float().sum().item()
                pbar.set_description(f'Valid Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
                pbar.update()

                for pred_label in predictions:
                    class_counts[pred_label] += 1

            running_loss_avg = running_loss / len(dataset_val)
            running_acc_avg = running_acc / len(dataset_val)
            val_loss.append(running_loss_avg)
            val_acc.append(running_acc_avg)
            print(f"Validation Loss: {running_loss_avg}, Accuracy: {running_acc_avg}")
            print(f"Class Counts: {class_counts}")
            if running_loss_avg < best_val_loss:
                best_val_loss = running_loss
                best_model = model
                torch.save(model.state_dict(), model_path)
        print("======================================")
    print("Finished Training")


    model.eval()
    with torch.no_grad():
        class_counts = [0] * 50
        running_loss = 0.0
        running_acc = 0.0
        predictions = []
        true_labels = []
        pbar = tqdm(data_loader_test, colour='red', total=len(data_loader_test))
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            _, predicted = torch.max(outputs, 1)
            predictions.extend(predicted.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())
            loss = criterion(outputs, labels)

            running_loss += loss.item()*len(labels)
            running_acc += (outputs.argmax(dim=1) == labels).float().sum().item()
            pbar.set_description(f'Test Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
            pbar.update()

        for pred_label in predictions:
            class_counts[pred_label] += 1
        
        precision = precision_score(true_labels, predictions, average='weighted')
        recall = recall_score(true_labels, predictions, average='weighted')
        f1 = f1_score(true_labels, predictions, average='weighted')
        print(f"Test Loss: {running_loss/len(dataset_test)}, Accuracy: {running_acc/len(dataset_test)}")
        print(f"Precision: {precision}, Recall: {recall}, F1: {f1}")
        print(f"Class Counts: {class_counts}")
    print("Finished Testing")

    train_record = {"train_loss": train_loss, "train_acc": train_acc, "val_loss": val_loss, "val_acc": val_acc}
    test_record = {"true_labels": true_labels, "predictions": predictions, "precision": precision, "recall": recall, "f1": f1, "class_counts": class_counts}


    with open(f'{save_path}/train_record/train_record_lr{lr}_{rec_name}.pkl', 'wb') as f:
        pickle.dump(train_record, f)
    with open(f'{save_path}/test_record/test_record_lr{lr}_{rec_name}.pkl', 'wb') as f:
        pickle.dump(test_record, f)


Train Iter: 001/010  Loss: 3.2933: 100%|[32m██████████[0m| 1979/1979 [03:44<00:00,  8.82it/s]


Epoch 1, Loss: 3.371419522149545, Accuracy: 0.1318278720884327


Valid Iter: 001/010  Loss: 2.0642: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.23it/s]


Validation Loss: 2.968172470728556, Accuracy: 0.21777777777777776
Class Counts: [3, 10, 5, 23, 19, 18, 2, 6, 4, 0, 0, 4, 0, 3, 9, 4, 1, 10, 1, 0, 17, 16, 9, 21, 1, 4, 2, 14, 22, 10, 4, 5, 3, 11, 25, 7, 12, 12, 4, 22, 9, 5, 0, 24, 10, 11, 13, 5, 20, 10]


Train Iter: 002/010  Loss: 3.2808: 100%|[32m██████████[0m| 1979/1979 [03:44<00:00,  8.81it/s]


Epoch 2, Loss: 3.0010386702668934, Accuracy: 0.1985945519147256


Valid Iter: 002/010  Loss: 1.5680: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 24.80it/s]


Validation Loss: 2.8043849881490073, Accuracy: 0.23777777777777778
Class Counts: [3, 4, 15, 7, 48, 14, 1, 3, 3, 0, 0, 1, 3, 0, 14, 16, 27, 6, 2, 7, 10, 8, 6, 9, 2, 1, 6, 13, 9, 12, 16, 0, 4, 12, 17, 10, 14, 7, 5, 15, 15, 6, 0, 15, 11, 8, 9, 3, 19, 14]


Train Iter: 003/010  Loss: 2.4335: 100%|[32m██████████[0m| 1979/1979 [03:46<00:00,  8.75it/s]


Epoch 3, Loss: 2.7665933227162562, Accuracy: 0.2505803395183577


Valid Iter: 003/010  Loss: 1.1063: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.12it/s]


Validation Loss: 2.6186005300945707, Accuracy: 0.27111111111111114
Class Counts: [3, 10, 11, 8, 19, 8, 6, 4, 7, 9, 3, 1, 3, 11, 6, 7, 10, 16, 3, 10, 11, 14, 13, 15, 5, 11, 2, 11, 12, 5, 2, 10, 10, 17, 9, 15, 5, 10, 6, 6, 20, 18, 7, 7, 9, 10, 13, 3, 9, 10]


Train Iter: 004/010  Loss: 2.5409: 100%|[32m██████████[0m| 1979/1979 [03:44<00:00,  8.80it/s]


Epoch 4, Loss: 2.512829645712574, Accuracy: 0.3103039873667588


Valid Iter: 004/010  Loss: 0.7784: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 24.85it/s]


Validation Loss: 2.5062840535905626, Accuracy: 0.31777777777777777
Class Counts: [8, 12, 11, 7, 17, 5, 7, 7, 19, 5, 3, 1, 2, 3, 8, 7, 9, 7, 6, 15, 11, 11, 9, 10, 6, 8, 6, 14, 5, 7, 4, 5, 7, 13, 15, 14, 8, 7, 10, 15, 14, 6, 7, 16, 7, 13, 16, 5, 12, 10]


Train Iter: 005/010  Loss: 2.5712: 100%|[32m██████████[0m| 1979/1979 [03:47<00:00,  8.71it/s]


Epoch 5, Loss: 2.2491621346404043, Accuracy: 0.3779549940781682


Valid Iter: 005/010  Loss: 0.7822: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.14it/s]


Validation Loss: 2.4279203769895767, Accuracy: 0.3244444444444444
Class Counts: [6, 7, 7, 12, 21, 8, 6, 7, 9, 5, 3, 2, 3, 7, 17, 7, 3, 8, 6, 16, 20, 5, 21, 4, 5, 13, 5, 19, 4, 8, 5, 7, 11, 13, 9, 6, 10, 8, 6, 10, 14, 12, 2, 19, 11, 12, 9, 4, 13, 5]


Train Iter: 006/010  Loss: 2.0018: 100%|[32m██████████[0m| 1979/1979 [03:44<00:00,  8.81it/s]


Epoch 6, Loss: 1.938186409378353, Accuracy: 0.46093959731543627


Valid Iter: 006/010  Loss: 0.5706: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 24.86it/s]


Validation Loss: 2.413328757815891, Accuracy: 0.31777777777777777
Class Counts: [9, 8, 9, 13, 18, 7, 4, 10, 6, 5, 7, 2, 4, 6, 8, 9, 4, 8, 8, 14, 17, 10, 11, 9, 5, 9, 12, 12, 3, 4, 11, 6, 12, 11, 8, 8, 15, 9, 7, 11, 12, 7, 6, 17, 11, 9, 9, 3, 13, 14]


Train Iter: 007/010  Loss: 2.0385: 100%|[32m██████████[0m| 1979/1979 [03:45<00:00,  8.79it/s]


Epoch 7, Loss: 1.614978291481039, Accuracy: 0.5541413343861035


Valid Iter: 007/010  Loss: 0.3280: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.08it/s]


Validation Loss: 2.379083115259806, Accuracy: 0.32
Class Counts: [7, 8, 14, 12, 17, 9, 4, 8, 10, 7, 7, 5, 4, 7, 14, 11, 4, 6, 6, 8, 16, 5, 13, 9, 11, 11, 3, 7, 6, 7, 10, 3, 12, 9, 7, 11, 13, 7, 8, 15, 10, 6, 8, 16, 8, 14, 9, 5, 11, 12]


Train Iter: 008/010  Loss: 1.2046: 100%|[32m██████████[0m| 1979/1979 [03:45<00:00,  8.79it/s]


Epoch 8, Loss: 1.2882313722647372, Accuracy: 0.6525542834583498


Valid Iter: 008/010  Loss: 0.2392: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.02it/s]


Validation Loss: 2.373409506612354, Accuracy: 0.3622222222222222
Class Counts: [10, 13, 9, 12, 21, 9, 6, 11, 8, 8, 3, 5, 4, 3, 14, 6, 10, 9, 6, 8, 15, 3, 16, 6, 6, 6, 6, 8, 8, 10, 8, 6, 8, 9, 7, 8, 11, 13, 7, 12, 14, 7, 3, 22, 7, 9, 7, 8, 13, 12]


Train Iter: 009/010  Loss: 0.8817: 100%|[32m██████████[0m| 1979/1979 [03:46<00:00,  8.75it/s]


Epoch 9, Loss: 0.9923234678670158, Accuracy: 0.7425503355704698


Valid Iter: 009/010  Loss: 0.6548: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.03it/s]


Validation Loss: 2.3479604694578384, Accuracy: 0.36
Class Counts: [12, 9, 10, 14, 23, 8, 2, 5, 8, 7, 7, 7, 4, 7, 14, 12, 3, 7, 9, 7, 13, 6, 10, 9, 10, 9, 5, 13, 4, 10, 7, 3, 10, 8, 7, 12, 10, 10, 10, 12, 12, 8, 6, 17, 6, 9, 10, 7, 11, 11]


Train Iter: 010/010  Loss: 1.2731: 100%|[32m██████████[0m| 1979/1979 [03:44<00:00,  8.81it/s]


Epoch 10, Loss: 0.7263026210473923, Accuracy: 0.8228819581523885


Valid Iter: 010/010  Loss: 0.1971: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.04it/s]


Validation Loss: 2.368212363587485, Accuracy: 0.36444444444444446
Class Counts: [7, 10, 11, 9, 23, 9, 4, 12, 13, 7, 4, 6, 5, 7, 13, 16, 7, 15, 8, 11, 12, 5, 5, 6, 4, 7, 6, 10, 6, 8, 12, 5, 7, 10, 9, 8, 14, 6, 7, 10, 12, 9, 7, 10, 8, 9, 11, 8, 10, 12]
Finished Training


Test Iter: 010/010  Loss: 0.1971: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.06it/s]

Test Loss: 0.1970612108707428, Accuracy: 0.36666666666666664
Precision: 0.3737601581425111, Recall: 0.36666666666666664, F1: 0.3616002912897515
Class Counts: [9, 12, 15, 6, 17, 10, 7, 12, 13, 10, 3, 6, 4, 10, 12, 14, 8, 10, 10, 9, 7, 5, 8, 3, 6, 12, 4, 7, 12, 8, 11, 7, 5, 12, 10, 8, 6, 7, 13, 7, 8, 10, 11, 7, 9, 10, 8, 11, 10, 11]
Finished Testing





In [3]:
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        
    def forward(self, x):
        batch_size, C, width, height = x.size()
        
        # 计算查询、键和值
        proj_query = self.query_conv(x).view(batch_size, -1, width * height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(batch_size, -1, width * height)
        energy = torch.bmm(proj_query, proj_key)
        
        # 计算注意力权重并应用到值上
        attention = F.softmax(energy, dim=-1)
        proj_value = self.value_conv(x).view(batch_size, -1, width * height)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        
        # 调整输出并加上残差连接
        out = out.view(batch_size, C, width, height)
        out = self.gamma * out + x
        return out

class CNNModel(nn.Module):
    def __init__(self, num_classes=50):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.attention1 = SelfAttention(256)

        self.fc1 = nn.Linear(256*28*28, 2048)
        self.fc2 = nn.Linear(2048, num_classes)
        self.activation = nn.ReLU()
        self.flatten = nn.Flatten()
        self.pool = nn.MaxPool2d(2, 2)
        # self.dropout25 = nn.Dropout(0.25)
        self.dropout50 = nn.Dropout(0.5)
        
    def forward(self, x):
        x = self.pool(self.activation(self.bn1(self.conv1(x))))
        x = self.pool(self.activation(self.bn2(self.conv2(x))))
        x = self.pool(self.activation(self.bn3(self.conv3(x))))
        x = self.attention1(x)
        x = self.dropout50(self.flatten(x))
        x = self.dropout50(self.activation(self.fc1(x)))
        x = self.fc2(x)
        return x


for lr in [1e-5]:
    model = CNNModel(num_classes=50).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.00025)
    num_epochs = 10
    save_path = 'record/Q2_less_layers'
    rec_name = '6layers_2048_dropout50_adamW_0.00025'
    model_path = f'{save_path}/model_weight/model_weight_lr{lr}_{rec_name}.pth'
    best_val_loss = float('inf')
    best_model = None
    train_loss = []
    train_acc = []
    val_loss = []
    val_acc = []
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_acc = 0.0
        pbar = tqdm(data_loader_train, colour='green', total=len(data_loader_train))
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()*len(labels)
            running_acc += (outputs.argmax(dim=1) == labels).float().sum().item()
            pbar.set_description(f'Train Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
            pbar.update()
        running_loss_avg = running_loss / len(dataset_train)
        running_acc_avg = running_acc / len(dataset_train)
        train_loss.append(running_loss_avg)
        train_acc.append(running_acc_avg)
        print(f"Epoch {epoch+1}, Loss: {running_loss_avg}, Accuracy: {running_acc_avg}")

        model.eval()
        with torch.no_grad():
            class_counts = [0] * 50
            running_loss = 0.0
            running_acc = 0.0
            pbar = tqdm(data_loader_val, colour='red', total=len(data_loader_val))
            for images, labels in pbar:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                predictions = outputs.argmax(dim=1)
                loss = criterion(outputs, labels)
                running_loss += loss.item()*len(labels)
                running_acc += (predictions == labels).float().sum().item()
                pbar.set_description(f'Valid Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
                pbar.update()

                for pred_label in predictions:
                    class_counts[pred_label] += 1

            running_loss_avg = running_loss / len(dataset_val)
            running_acc_avg = running_acc / len(dataset_val)
            val_loss.append(running_loss_avg)
            val_acc.append(running_acc_avg)
            print(f"Validation Loss: {running_loss_avg}, Accuracy: {running_acc_avg}")
            print(f"Class Counts: {class_counts}")
            if running_loss_avg < best_val_loss:
                best_val_loss = running_loss
                best_model = model
                torch.save(model.state_dict(), model_path)
        print("======================================")
    print("Finished Training")


    model.eval()
    with torch.no_grad():
        class_counts = [0] * 50
        running_loss = 0.0
        running_acc = 0.0
        predictions = []
        true_labels = []
        pbar = tqdm(data_loader_test, colour='red', total=len(data_loader_test))
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            _, predicted = torch.max(outputs, 1)
            predictions.extend(predicted.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())
            loss = criterion(outputs, labels)

            running_loss += loss.item()*len(labels)
            running_acc += (outputs.argmax(dim=1) == labels).float().sum().item()
            pbar.set_description(f'Test Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
            pbar.update()

        for pred_label in predictions:
            class_counts[pred_label] += 1
        
        precision = precision_score(true_labels, predictions, average='weighted')
        recall = recall_score(true_labels, predictions, average='weighted')
        f1 = f1_score(true_labels, predictions, average='weighted')
        print(f"Test Loss: {running_loss/len(dataset_test)}, Accuracy: {running_acc/len(dataset_test)}")
        print(f"Precision: {precision}, Recall: {recall}, F1: {f1}")
        print(f"Class Counts: {class_counts}")
    print("Finished Testing")

    train_record = {"train_loss": train_loss, "train_acc": train_acc, "val_loss": val_loss, "val_acc": val_acc}
    test_record = {"true_labels": true_labels, "predictions": predictions, "precision": precision, "recall": recall, "f1": f1, "class_counts": class_counts}


    with open(f'{save_path}/train_record/train_record_lr{lr}_{rec_name}.pkl', 'wb') as f:
        pickle.dump(train_record, f)
    with open(f'{save_path}/test_record/test_record_lr{lr}_{rec_name}.pkl', 'wb') as f:
        pickle.dump(test_record, f)


Train Iter: 001/010  Loss: 3.0217: 100%|[32m██████████[0m| 1979/1979 [03:48<00:00,  8.65it/s]


Epoch 1, Loss: 3.3622765291381906, Accuracy: 0.13136991709435453


Valid Iter: 001/010  Loss: 2.1082: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.22it/s]


Validation Loss: 2.963844184875488, Accuracy: 0.18666666666666668
Class Counts: [14, 5, 22, 7, 22, 5, 9, 6, 2, 4, 0, 0, 0, 3, 9, 16, 4, 0, 10, 1, 27, 2, 6, 15, 9, 1, 1, 17, 3, 11, 5, 0, 20, 11, 10, 13, 11, 13, 2, 24, 12, 6, 13, 14, 7, 13, 10, 4, 18, 13]


Train Iter: 002/010  Loss: 2.7883: 100%|[32m██████████[0m| 1979/1979 [03:47<00:00,  8.69it/s]


Epoch 2, Loss: 2.984173423678236, Accuracy: 0.2028740623766285


Valid Iter: 002/010  Loss: 1.3706: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 24.27it/s]


Validation Loss: 2.750908899307251, Accuracy: 0.2577777777777778
Class Counts: [5, 10, 4, 33, 24, 9, 3, 4, 18, 2, 1, 0, 3, 2, 14, 4, 11, 12, 11, 13, 10, 7, 7, 15, 3, 14, 2, 8, 6, 2, 8, 1, 9, 9, 8, 3, 17, 11, 11, 8, 9, 13, 4, 27, 8, 8, 12, 7, 14, 6]


Train Iter: 003/010  Loss: 2.3641: 100%|[32m██████████[0m| 1979/1979 [03:47<00:00,  8.69it/s]


Epoch 3, Loss: 2.7402663750503544, Accuracy: 0.2561547572048954


Valid Iter: 003/010  Loss: 0.8221: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 23.52it/s]


Validation Loss: 2.6176902394824557, Accuracy: 0.27555555555555555
Class Counts: [7, 15, 8, 8, 30, 11, 6, 13, 5, 5, 3, 3, 4, 3, 7, 7, 10, 2, 2, 17, 9, 17, 7, 9, 14, 12, 4, 2, 25, 3, 5, 1, 5, 15, 13, 6, 9, 8, 7, 8, 18, 8, 1, 19, 9, 13, 15, 7, 7, 8]


Train Iter: 004/010  Loss: 2.2002: 100%|[32m██████████[0m| 1979/1979 [03:47<00:00,  8.69it/s]


Epoch 4, Loss: 2.478454423170997, Accuracy: 0.31838926174496646


Valid Iter: 004/010  Loss: 0.3995: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 24.67it/s]


Validation Loss: 2.491275388929579, Accuracy: 0.3022222222222222
Class Counts: [8, 11, 15, 14, 16, 3, 6, 6, 7, 6, 3, 3, 1, 6, 6, 12, 12, 9, 9, 18, 8, 12, 5, 6, 1, 9, 6, 8, 13, 9, 5, 10, 12, 10, 7, 8, 9, 12, 7, 15, 16, 7, 14, 10, 12, 11, 9, 8, 11, 9]


Train Iter: 005/010  Loss: 2.0639: 100%|[32m██████████[0m| 1979/1979 [03:46<00:00,  8.73it/s]


Epoch 5, Loss: 2.19708090954979, Accuracy: 0.39303592577970786


Valid Iter: 005/010  Loss: 0.3991: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.36it/s]


Validation Loss: 2.4123817547162374, Accuracy: 0.32222222222222224
Class Counts: [4, 3, 6, 10, 17, 11, 6, 14, 8, 5, 3, 0, 4, 11, 15, 6, 16, 9, 12, 17, 4, 10, 7, 5, 7, 7, 8, 11, 11, 5, 9, 7, 3, 10, 14, 11, 15, 6, 11, 15, 17, 5, 5, 13, 8, 12, 12, 4, 12, 9]


Train Iter: 006/010  Loss: 1.7949: 100%|[32m██████████[0m| 1979/1979 [03:46<00:00,  8.75it/s]


Epoch 6, Loss: 1.8869334942523641, Accuracy: 0.4754520331622582


Valid Iter: 006/010  Loss: 0.2723: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 23.95it/s]


Validation Loss: 2.3277376794815066, Accuracy: 0.34
Class Counts: [12, 6, 10, 11, 19, 8, 4, 11, 6, 5, 6, 1, 5, 5, 17, 5, 4, 7, 6, 11, 16, 6, 14, 12, 7, 9, 5, 7, 12, 11, 5, 6, 8, 14, 10, 14, 14, 6, 12, 9, 11, 9, 3, 13, 7, 8, 10, 12, 11, 10]


Train Iter: 007/010  Loss: 1.7049: 100%|[32m██████████[0m| 1979/1979 [03:46<00:00,  8.74it/s]


Epoch 7, Loss: 1.5582245665710996, Accuracy: 0.5688432688511647


Valid Iter: 007/010  Loss: 0.2214: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 24.55it/s]


Validation Loss: 2.3412856198681724, Accuracy: 0.34444444444444444
Class Counts: [3, 9, 8, 15, 21, 12, 4, 8, 5, 2, 6, 3, 4, 9, 10, 14, 6, 4, 13, 11, 8, 1, 14, 6, 10, 5, 10, 9, 7, 13, 4, 9, 13, 7, 4, 12, 16, 8, 11, 17, 12, 6, 4, 19, 12, 12, 10, 5, 12, 7]


Train Iter: 008/010  Loss: 1.0828: 100%|[32m██████████[0m| 1979/1979 [03:47<00:00,  8.71it/s]


Epoch 8, Loss: 1.2337543193615335, Accuracy: 0.6673509672325306


Valid Iter: 008/010  Loss: 0.4355: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 25.00it/s]


Validation Loss: 2.321160083346897, Accuracy: 0.3377777777777778
Class Counts: [9, 14, 6, 15, 18, 12, 6, 8, 3, 4, 11, 1, 6, 6, 11, 8, 7, 2, 12, 11, 12, 5, 6, 7, 8, 7, 9, 12, 7, 5, 4, 11, 11, 12, 13, 6, 9, 8, 13, 11, 14, 7, 11, 12, 8, 10, 10, 7, 16, 9]


Train Iter: 009/010  Loss: 1.0819: 100%|[32m██████████[0m| 1979/1979 [03:46<00:00,  8.73it/s]


Epoch 9, Loss: 0.94251945330322, Accuracy: 0.756525858665614


Valid Iter: 009/010  Loss: 0.1305: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 24.81it/s]


Validation Loss: 2.307536218563716, Accuracy: 0.3688888888888889
Class Counts: [8, 11, 7, 10, 21, 12, 4, 7, 5, 4, 7, 4, 5, 7, 11, 12, 11, 7, 10, 16, 12, 4, 8, 5, 7, 9, 6, 9, 6, 4, 5, 14, 6, 10, 8, 13, 9, 18, 11, 12, 12, 4, 5, 17, 8, 15, 8, 8, 6, 12]


Train Iter: 010/010  Loss: 0.6759: 100%|[32m██████████[0m| 1979/1979 [03:46<00:00,  8.73it/s]


Epoch 10, Loss: 0.6981750429961411, Accuracy: 0.8290248716936439


Valid Iter: 010/010  Loss: 0.3383: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 24.91it/s]


Validation Loss: 2.3225330283906724, Accuracy: 0.35333333333333333
Class Counts: [10, 8, 7, 16, 21, 10, 3, 7, 6, 8, 9, 3, 6, 11, 8, 5, 7, 9, 8, 12, 16, 4, 10, 8, 4, 3, 12, 10, 5, 12, 7, 8, 10, 12, 6, 11, 10, 13, 9, 10, 13, 5, 9, 14, 9, 10, 9, 6, 12, 9]
Finished Training


Test Iter: 010/010  Loss: 0.3383: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 24.42it/s]

Test Loss: 0.3383125066757202, Accuracy: 0.38
Precision: 0.39352695017400896, Recall: 0.38, F1: 0.3740236531390118
Class Counts: [6, 8, 11, 11, 12, 12, 8, 7, 5, 4, 6, 4, 6, 17, 9, 11, 12, 11, 14, 11, 11, 3, 8, 10, 8, 7, 4, 9, 10, 11, 5, 12, 6, 10, 9, 11, 3, 11, 11, 11, 8, 9, 12, 5, 8, 14, 12, 6, 10, 11]
Finished Testing





In [4]:
with open('images/images.pickle', 'rb') as f:
    img_data = pickle.load(f)
class CustomDataset(Dataset):
    def __init__(self, img_data, txt_file, transform=None):
        self.data = img_data
        with open(txt_file, 'r') as f:
            self.labels = f.readlines()
        self.transform = transform
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, idx):
        img_path, label = self.labels[idx].strip().split()
        img_path = img_path.split('/')
        img = self.data[img_path[1]][img_path[2]]
        img = Image.open(io.BytesIO(img)).convert('RGB')
        if self.transform:
            img = self.transform(img)
        label = torch.tensor(int(label))
        return img, label

transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(degrees=10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
dataset_train = CustomDataset(img_data=img_data, txt_file='images/train.txt', transform=transform_train)
dataset_val = CustomDataset(img_data=img_data, txt_file='images/val.txt', transform=transform)
dataset_test = CustomDataset(img_data=img_data, txt_file='images/test.txt', transform=transform)


data_loader_train = DataLoader(dataset_train, batch_size=32, shuffle=True, )
data_loader_val = DataLoader(dataset_val, batch_size=32, shuffle=False)
data_loader_test = DataLoader(dataset_test, batch_size=32, shuffle=False)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        
    def forward(self, x):
        batch_size, C, width, height = x.size()
        
        # 计算查询、键和值
        proj_query = self.query_conv(x).view(batch_size, -1, width * height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(batch_size, -1, width * height)
        energy = torch.bmm(proj_query, proj_key)
        
        # 计算注意力权重并应用到值上
        attention = F.softmax(energy, dim=-1)
        proj_value = self.value_conv(x).view(batch_size, -1, width * height)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        
        # 调整输出并加上残差连接
        out = out.view(batch_size, C, width, height)
        out = self.gamma * out + x
        return out

class CNNModel(nn.Module):
    def __init__(self, num_classes=50):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.attention1 = SelfAttention(256)

        self.fc1 = nn.Linear(256*28*28, 2048)
        self.fc2 = nn.Linear(2048, num_classes)
        self.activation = nn.ReLU()
        self.flatten = nn.Flatten()
        self.pool = nn.MaxPool2d(2, 2)
        # self.dropout25 = nn.Dropout(0.25)
        self.dropout50 = nn.Dropout(0.5)
        
    def forward(self, x):
        x = self.pool(self.activation(self.bn1(self.conv1(x))))
        x = self.pool(self.activation(self.bn2(self.conv2(x))))
        x = self.pool(self.activation(self.bn3(self.conv3(x))))
        x = self.attention1(x)
        x = self.dropout50(self.flatten(x))
        x = self.dropout50(self.activation(self.fc1(x)))
        x = self.fc2(x)
        return x


for lr in [1e-5]:
    model = CNNModel(num_classes=50).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.00025)
    num_epochs = 10
    save_path = 'record/Q2_less_layers'
    rec_name = '6layers_2048_dropout50_adamW_0.00025_augment'
    model_path = f'{save_path}/model_weight/model_weight_lr{lr}_{rec_name}.pth'
    best_val_loss = float('inf')
    best_model = None
    train_loss = []
    train_acc = []
    val_loss = []
    val_acc = []
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_acc = 0.0
        pbar = tqdm(data_loader_train, colour='green', total=len(data_loader_train))
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()*len(labels)
            running_acc += (outputs.argmax(dim=1) == labels).float().sum().item()
            pbar.set_description(f'Train Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
            pbar.update()
        running_loss_avg = running_loss / len(dataset_train)
        running_acc_avg = running_acc / len(dataset_train)
        train_loss.append(running_loss_avg)
        train_acc.append(running_acc_avg)
        print(f"Epoch {epoch+1}, Loss: {running_loss_avg}, Accuracy: {running_acc_avg}")

        model.eval()
        with torch.no_grad():
            class_counts = [0] * 50
            running_loss = 0.0
            running_acc = 0.0
            pbar = tqdm(data_loader_val, colour='red', total=len(data_loader_val))
            for images, labels in pbar:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                predictions = outputs.argmax(dim=1)
                loss = criterion(outputs, labels)
                running_loss += loss.item()*len(labels)
                running_acc += (predictions == labels).float().sum().item()
                pbar.set_description(f'Valid Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
                pbar.update()

                for pred_label in predictions:
                    class_counts[pred_label] += 1

            running_loss_avg = running_loss / len(dataset_val)
            running_acc_avg = running_acc / len(dataset_val)
            val_loss.append(running_loss_avg)
            val_acc.append(running_acc_avg)
            print(f"Validation Loss: {running_loss_avg}, Accuracy: {running_acc_avg}")
            print(f"Class Counts: {class_counts}")
            if running_loss_avg < best_val_loss:
                best_val_loss = running_loss
                best_model = model
                torch.save(model.state_dict(), model_path)
        print("======================================")
    print("Finished Training")


    model.eval()
    with torch.no_grad():
        class_counts = [0] * 50
        running_loss = 0.0
        running_acc = 0.0
        predictions = []
        true_labels = []
        pbar = tqdm(data_loader_test, colour='red', total=len(data_loader_test))
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            _, predicted = torch.max(outputs, 1)
            predictions.extend(predicted.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())
            loss = criterion(outputs, labels)

            running_loss += loss.item()*len(labels)
            running_acc += (outputs.argmax(dim=1) == labels).float().sum().item()
            pbar.set_description(f'Test Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
            pbar.update()

        for pred_label in predictions:
            class_counts[pred_label] += 1
        
        precision = precision_score(true_labels, predictions, average='weighted')
        recall = recall_score(true_labels, predictions, average='weighted')
        f1 = f1_score(true_labels, predictions, average='weighted')
        print(f"Test Loss: {running_loss/len(dataset_test)}, Accuracy: {running_acc/len(dataset_test)}")
        print(f"Precision: {precision}, Recall: {recall}, F1: {f1}")
        print(f"Class Counts: {class_counts}")
    print("Finished Testing")

    train_record = {"train_loss": train_loss, "train_acc": train_acc, "val_loss": val_loss, "val_acc": val_acc}
    test_record = {"true_labels": true_labels, "predictions": predictions, "precision": precision, "recall": recall, "f1": f1, "class_counts": class_counts}


    with open(f'{save_path}/train_record/train_record_lr{lr}_{rec_name}.pkl', 'wb') as f:
        pickle.dump(train_record, f)
    with open(f'{save_path}/test_record/test_record_lr{lr}_{rec_name}.pkl', 'wb') as f:
        pickle.dump(test_record, f)


Train Iter: 001/010  Loss: 3.6530: 100%|[32m██████████[0m| 1979/1979 [06:31<00:00,  5.06it/s]


Epoch 1, Loss: 3.7190854269367315, Accuracy: 0.0704776944334781


Valid Iter: 001/010  Loss: 2.6371: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 20.62it/s]


Validation Loss: 3.3658720927768284, Accuracy: 0.13111111111111112
Class Counts: [14, 11, 17, 6, 22, 0, 0, 2, 0, 1, 0, 0, 0, 1, 8, 12, 3, 0, 7, 0, 16, 2, 0, 3, 0, 1, 0, 5, 24, 0, 0, 0, 103, 28, 16, 21, 11, 14, 1, 6, 3, 2, 0, 5, 8, 12, 3, 3, 16, 43]


Train Iter: 002/010  Loss: 3.7509: 100%|[32m██████████[0m| 1979/1979 [06:39<00:00,  4.96it/s]


Epoch 2, Loss: 3.5534130832643505, Accuracy: 0.0975759968416897


Valid Iter: 002/010  Loss: 2.7354: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 20.68it/s]


Validation Loss: 3.2091529422336156, Accuracy: 0.15555555555555556
Class Counts: [13, 2, 18, 15, 9, 1, 0, 1, 0, 0, 0, 1, 0, 2, 6, 4, 2, 7, 6, 1, 24, 1, 23, 7, 1, 15, 0, 10, 4, 7, 1, 5, 3, 38, 24, 4, 18, 19, 3, 25, 5, 19, 1, 9, 18, 9, 14, 11, 18, 26]


Train Iter: 003/010  Loss: 3.3217: 100%|[32m██████████[0m| 1979/1979 [06:35<00:00,  5.00it/s]


Epoch 3, Loss: 3.4619255359813814, Accuracy: 0.11371496249506514


Valid Iter: 003/010  Loss: 2.7233: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 20.71it/s]


Validation Loss: 3.1123110156589084, Accuracy: 0.16666666666666666
Class Counts: [8, 2, 14, 4, 23, 2, 0, 5, 1, 2, 2, 0, 0, 5, 12, 22, 6, 8, 7, 8, 7, 3, 3, 10, 1, 19, 6, 7, 4, 6, 0, 1, 6, 33, 26, 20, 11, 23, 3, 19, 14, 6, 13, 13, 9, 11, 9, 3, 18, 15]


Train Iter: 004/010  Loss: 3.3374: 100%|[32m██████████[0m| 1979/1979 [06:38<00:00,  4.97it/s]


Epoch 4, Loss: 3.4070958235100988, Accuracy: 0.12249506514015002


Valid Iter: 004/010  Loss: 3.1339: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 20.35it/s]


Validation Loss: 3.047256942325168, Accuracy: 0.19111111111111112
Class Counts: [6, 1, 19, 9, 20, 1, 4, 12, 3, 0, 0, 1, 0, 1, 14, 12, 4, 4, 4, 5, 23, 11, 16, 7, 15, 10, 2, 4, 2, 17, 3, 5, 18, 16, 9, 4, 10, 14, 8, 19, 4, 11, 8, 25, 12, 7, 6, 18, 19, 7]


Train Iter: 005/010  Loss: 3.3954: 100%|[32m██████████[0m| 1979/1979 [06:34<00:00,  5.02it/s]


Epoch 5, Loss: 3.3632060979224927, Accuracy: 0.12919068298460323


Valid Iter: 005/010  Loss: 3.0955: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 20.83it/s]


Validation Loss: 2.9662327935960557, Accuracy: 0.2
Class Counts: [13, 2, 14, 5, 27, 3, 7, 10, 2, 0, 0, 1, 0, 21, 5, 6, 3, 4, 8, 5, 18, 22, 7, 14, 2, 9, 12, 10, 6, 4, 5, 4, 11, 14, 6, 6, 20, 6, 12, 22, 12, 5, 7, 17, 22, 12, 9, 3, 17, 0]


Train Iter: 006/010  Loss: 3.3021: 100%|[32m██████████[0m| 1979/1979 [06:29<00:00,  5.08it/s]


Epoch 6, Loss: 3.321628724228851, Accuracy: 0.1381760757994473


Valid Iter: 006/010  Loss: 2.2943: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 20.90it/s]


Validation Loss: 2.9444251611497667, Accuracy: 0.21555555555555556
Class Counts: [8, 10, 3, 6, 23, 8, 6, 4, 2, 1, 2, 3, 0, 7, 5, 21, 4, 2, 4, 12, 31, 8, 13, 12, 2, 10, 5, 7, 3, 11, 3, 7, 8, 7, 6, 8, 10, 13, 10, 29, 3, 12, 14, 10, 12, 10, 22, 3, 22, 8]


Train Iter: 007/010  Loss: 3.1473: 100%|[32m██████████[0m| 1979/1979 [06:25<00:00,  5.13it/s]


Epoch 7, Loss: 3.2766119434651864, Accuracy: 0.1466719305171733


Valid Iter: 007/010  Loss: 2.1262: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 20.70it/s]


Validation Loss: 2.893872439066569, Accuracy: 0.21777777777777776
Class Counts: [10, 5, 9, 10, 17, 2, 8, 4, 4, 0, 0, 0, 0, 5, 2, 18, 5, 7, 3, 7, 13, 9, 26, 14, 7, 18, 8, 12, 4, 7, 4, 4, 16, 8, 4, 7, 14, 7, 4, 27, 8, 6, 34, 9, 13, 9, 17, 6, 14, 5]


Train Iter: 008/010  Loss: 3.2701: 100%|[32m██████████[0m| 1979/1979 [06:21<00:00,  5.19it/s]


Epoch 8, Loss: 3.2465312433487523, Accuracy: 0.1515515199368338


Valid Iter: 008/010  Loss: 2.2980: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 19.66it/s]


Validation Loss: 2.8175437376234265, Accuracy: 0.24
Class Counts: [7, 6, 11, 3, 22, 3, 3, 7, 7, 0, 0, 6, 0, 4, 16, 26, 11, 2, 2, 18, 19, 5, 5, 7, 2, 16, 4, 3, 8, 16, 3, 5, 14, 10, 6, 10, 12, 9, 9, 20, 9, 9, 20, 20, 14, 6, 4, 7, 21, 3]


Train Iter: 009/010  Loss: 3.0081: 100%|[32m██████████[0m| 1979/1979 [06:22<00:00,  5.17it/s]


Epoch 9, Loss: 3.210858778639175, Accuracy: 0.15887879984208447


Valid Iter: 009/010  Loss: 2.3781: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 20.67it/s]


Validation Loss: 2.7785152520073786, Accuracy: 0.24666666666666667
Class Counts: [17, 9, 5, 3, 21, 2, 9, 7, 17, 0, 1, 1, 1, 15, 6, 17, 3, 4, 6, 14, 22, 3, 9, 16, 5, 22, 4, 12, 8, 7, 5, 3, 8, 5, 10, 14, 8, 12, 8, 20, 11, 8, 4, 14, 8, 8, 13, 2, 19, 4]


Train Iter: 010/010  Loss: 3.2057: 100%|[32m██████████[0m| 1979/1979 [06:25<00:00,  5.14it/s]


Epoch 10, Loss: 3.184411486208886, Accuracy: 0.16603237268061588


Valid Iter: 010/010  Loss: 1.9690: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 21.01it/s]


Validation Loss: 2.749142740037706, Accuracy: 0.22
Class Counts: [6, 18, 6, 5, 25, 8, 5, 3, 2, 0, 2, 2, 1, 13, 6, 11, 3, 3, 1, 6, 25, 10, 11, 8, 5, 20, 7, 16, 8, 2, 8, 13, 12, 3, 6, 14, 7, 10, 9, 21, 14, 6, 11, 12, 7, 7, 7, 9, 19, 17]
Finished Training


Test Iter: 010/010  Loss: 1.9690: 100%|[31m██████████[0m| 15/15 [00:00<00:00, 20.90it/s]

Test Loss: 1.9690446853637695, Accuracy: 0.2688888888888889
Precision: 0.2634189452805013, Recall: 0.2688888888888889, F1: 0.24570814122970888
Class Counts: [8, 15, 5, 9, 10, 14, 3, 7, 4, 0, 1, 2, 0, 14, 4, 14, 5, 1, 1, 7, 18, 8, 11, 7, 2, 22, 3, 15, 11, 4, 9, 13, 7, 13, 8, 7, 6, 7, 18, 15, 23, 3, 11, 6, 10, 10, 13, 9, 18, 19]
Finished Testing



  _warn_prf(average, modifier, msg_start, len(result))


In [15]:
with open('images/images.pickle', 'rb') as f:
    img_data = pickle.load(f)
class CustomDataset(Dataset):
    def __init__(self, img_data, txt_file, transform=None):
        self.data = img_data
        with open(txt_file, 'r') as f:
            self.labels = f.readlines()
        self.transform = transform
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, idx):
        img_path, label = self.labels[idx].strip().split()
        img_path = img_path.split('/')
        img = self.data[img_path[1]][img_path[2]]
        img = Image.open(io.BytesIO(img)).convert('RGB')
        if self.transform:
            img = self.transform(img)
        label = torch.tensor(int(label))
        return img, label

transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.1, contrast=0.2, saturation=0.2, hue=0),
    transforms.RandomRotation(degrees=10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
dataset_train = CustomDataset(img_data=img_data, txt_file='images/train.txt', transform=transform_train)
dataset_val = CustomDataset(img_data=img_data, txt_file='images/val.txt', transform=transform)
dataset_test = CustomDataset(img_data=img_data, txt_file='images/test.txt', transform=transform)


data_loader_train = DataLoader(dataset_train, batch_size=32, shuffle=True, num_workers=16)
data_loader_val = DataLoader(dataset_val, batch_size=32, shuffle=False, num_workers=16)
data_loader_test = DataLoader(dataset_test, batch_size=32, shuffle=False, num_workers=16)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        
    def forward(self, x):
        batch_size, C, width, height = x.size()
        
        # 计算查询、键和值
        proj_query = self.query_conv(x).view(batch_size, -1, width * height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(batch_size, -1, width * height)
        energy = torch.bmm(proj_query, proj_key)
        
        # 计算注意力权重并应用到值上
        attention = F.softmax(energy, dim=-1)
        proj_value = self.value_conv(x).view(batch_size, -1, width * height)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        
        # 调整输出并加上残差连接
        out = out.view(batch_size, C, width, height)
        out = self.gamma * out + x
        return out

class CNNModel(nn.Module):
    def __init__(self, num_classes=50):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.attention1 = SelfAttention(256)

        self.fc1 = nn.Linear(256*28*28, 2048)
        self.fc2 = nn.Linear(2048, num_classes)
        self.activation = nn.ReLU()
        self.flatten = nn.Flatten()
        self.pool = nn.MaxPool2d(2, 2)
        # self.dropout25 = nn.Dropout(0.25)
        self.dropout50 = nn.Dropout(0.5)
        
    def forward(self, x):
        x = self.pool(self.activation(self.bn1(self.conv1(x))))
        x = self.pool(self.activation(self.bn2(self.conv2(x))))
        x = self.pool(self.activation(self.bn3(self.conv3(x))))
        x = self.attention1(x)
        x = self.dropout50(self.flatten(x))
        x = self.dropout50(self.activation(self.fc1(x)))
        x = self.fc2(x)
        return x


for lr in [1e-4, 1e-5, 1e-6]:
    model = CNNModel(num_classes=50).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.00025)
    num_epochs = 10
    save_path = 'record/Q2_less_layers'
    rec_name = '6layers_2048_dropout50_adamW_0.00025_augmentless'
    model_path = f'{save_path}/model_weight/model_weight_lr{lr}_{rec_name}.pth'
    best_val_loss = float('inf')
    best_model = None
    train_loss = []
    train_acc = []
    val_loss = []
    val_acc = []
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_acc = 0.0
        pbar = tqdm(data_loader_train, colour='green', total=len(data_loader_train))
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()*len(labels)
            running_acc += (outputs.argmax(dim=1) == labels).float().sum().item()
            pbar.set_description(f'Train Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
            pbar.update()
        running_loss_avg = running_loss / len(dataset_train)
        running_acc_avg = running_acc / len(dataset_train)
        train_loss.append(running_loss_avg)
        train_acc.append(running_acc_avg)
        print(f"Epoch {epoch+1}, Loss: {running_loss_avg}, Accuracy: {running_acc_avg}")

        model.eval()
        with torch.no_grad():
            class_counts = [0] * 50
            running_loss = 0.0
            running_acc = 0.0
            pbar = tqdm(data_loader_val, colour='red', total=len(data_loader_val))
            for images, labels in pbar:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                predictions = outputs.argmax(dim=1)
                loss = criterion(outputs, labels)
                running_loss += loss.item()*len(labels)
                running_acc += (predictions == labels).float().sum().item()
                pbar.set_description(f'Valid Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
                pbar.update()

                for pred_label in predictions:
                    class_counts[pred_label] += 1

            running_loss_avg = running_loss / len(dataset_val)
            running_acc_avg = running_acc / len(dataset_val)
            val_loss.append(running_loss_avg)
            val_acc.append(running_acc_avg)
            print(f"Validation Loss: {running_loss_avg}, Accuracy: {running_acc_avg}")
            print(f"Class Counts: {class_counts}")
            if running_loss_avg < best_val_loss:
                best_val_loss = running_loss
                best_model = model
                torch.save(model.state_dict(), model_path)
        print("======================================")
    print("Finished Training")


    model.eval()
    with torch.no_grad():
        class_counts = [0] * 50
        running_loss = 0.0
        running_acc = 0.0
        predictions = []
        true_labels = []
        pbar = tqdm(data_loader_test, colour='red', total=len(data_loader_test))
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            _, predicted = torch.max(outputs, 1)
            predictions.extend(predicted.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()*len(labels)
            running_acc += (outputs.argmax(dim=1) == labels).float().sum().item()
            pbar.set_description(f'Test Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
            pbar.update()

        for pred_label in predictions:
            class_counts[pred_label] += 1
        
        precision = precision_score(true_labels, predictions, average='weighted')
        recall = recall_score(true_labels, predictions, average='weighted')
        f1 = f1_score(true_labels, predictions, average='weighted')
        print(f"Test Loss: {running_loss/len(dataset_test)}, Accuracy: {running_acc/len(dataset_test)}")
        print(f"Precision: {precision}, Recall: {recall}, F1: {f1}")
        print(f"Class Counts: {class_counts}")
    print("Finished Testing")

    train_record = {"train_loss": train_loss, "train_acc": train_acc, "val_loss": val_loss, "val_acc": val_acc}
    test_record = {"true_labels": true_labels, "predictions": predictions, "precision": precision, "recall": recall, "f1": f1, "class_counts": class_counts}


    with open(f'{save_path}/train_record/train_record_lr{lr}_{rec_name}.pkl', 'wb') as f:
        pickle.dump(train_record, f)
    with open(f'{save_path}/test_record/test_record_lr{lr}_{rec_name}.pkl', 'wb') as f:
        pickle.dump(test_record, f)


Train Iter: 001/010  Loss: 3.3986: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.26it/s]


Epoch 1, Loss: 3.7106129950061413, Accuracy: 0.08521121200157916


Valid Iter: 001/010  Loss: 2.2721: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.24it/s]


Validation Loss: 3.1425390021006265, Accuracy: 0.14666666666666667
Class Counts: [14, 3, 13, 6, 25, 10, 6, 5, 0, 3, 0, 0, 0, 3, 2, 7, 11, 1, 5, 25, 3, 1, 17, 10, 3, 4, 4, 8, 26, 11, 3, 3, 24, 13, 12, 15, 0, 18, 5, 12, 2, 18, 16, 9, 8, 13, 19, 6, 23, 5]


Train Iter: 002/010  Loss: 2.9060: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.26it/s]


Epoch 2, Loss: 3.242185457292113, Accuracy: 0.13347019344650612


Valid Iter: 002/010  Loss: 2.7701: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.63it/s]


Validation Loss: 2.953703474468655, Accuracy: 0.1688888888888889
Class Counts: [9, 11, 26, 3, 45, 15, 4, 1, 0, 0, 0, 0, 0, 14, 7, 12, 1, 10, 16, 8, 27, 0, 3, 6, 24, 1, 1, 4, 5, 2, 1, 2, 2, 16, 13, 19, 6, 5, 3, 11, 7, 23, 8, 7, 10, 16, 25, 9, 8, 4]


Train Iter: 003/010  Loss: 3.2049: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.27it/s]


Epoch 3, Loss: 3.1072938430483426, Accuracy: 0.15756810106592972


Valid Iter: 003/010  Loss: 1.0536: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.85it/s]


Validation Loss: 2.7748839902877807, Accuracy: 0.23777777777777778
Class Counts: [10, 3, 15, 6, 29, 14, 1, 3, 1, 1, 0, 0, 0, 9, 7, 27, 2, 0, 8, 15, 13, 7, 4, 13, 6, 6, 3, 13, 5, 10, 1, 15, 17, 18, 2, 18, 13, 4, 10, 26, 8, 4, 3, 19, 8, 13, 13, 2, 15, 10]


Train Iter: 004/010  Loss: 2.6992: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.27it/s]


Epoch 4, Loss: 3.0013034530689344, Accuracy: 0.1787761547572049


Valid Iter: 004/010  Loss: 1.2036: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.52it/s]


Validation Loss: 2.732317868338691, Accuracy: 0.24666666666666667
Class Counts: [13, 7, 24, 12, 29, 23, 4, 7, 1, 0, 1, 0, 0, 8, 8, 20, 4, 4, 0, 10, 16, 8, 5, 15, 6, 8, 12, 8, 3, 6, 2, 6, 17, 20, 1, 15, 5, 12, 13, 5, 5, 11, 6, 10, 10, 22, 10, 0, 8, 10]


Train Iter: 005/010  Loss: 2.7850: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.26it/s]


Epoch 5, Loss: 2.9136688439760383, Accuracy: 0.1978523489932886


Valid Iter: 005/010  Loss: 1.1194: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.49it/s]


Validation Loss: 2.595430621571011, Accuracy: 0.2822222222222222
Class Counts: [9, 14, 1, 4, 34, 27, 5, 2, 9, 0, 2, 0, 0, 16, 9, 12, 4, 1, 10, 5, 23, 3, 12, 10, 13, 6, 5, 8, 5, 2, 7, 4, 12, 15, 4, 8, 9, 6, 11, 26, 4, 7, 6, 19, 8, 7, 19, 3, 15, 9]


Train Iter: 006/010  Loss: 2.3742: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.27it/s]


Epoch 6, Loss: 2.846749017405218, Accuracy: 0.21053296486379786


Valid Iter: 006/010  Loss: 1.1840: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.90it/s]


Validation Loss: 2.5854694128036497, Accuracy: 0.2777777777777778
Class Counts: [4, 9, 24, 26, 17, 9, 5, 5, 16, 2, 0, 2, 2, 8, 6, 5, 4, 24, 29, 2, 7, 1, 3, 3, 12, 17, 7, 11, 6, 0, 2, 5, 4, 9, 12, 12, 3, 9, 6, 13, 9, 16, 1, 14, 12, 16, 9, 6, 17, 9]


Train Iter: 007/010  Loss: 3.0349: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.25it/s]


Epoch 7, Loss: 2.7857188137068, Accuracy: 0.22771417291748913


Valid Iter: 007/010  Loss: 1.1822: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.80it/s]


Validation Loss: 2.501669256952074, Accuracy: 0.30444444444444446
Class Counts: [10, 2, 15, 3, 26, 21, 11, 8, 12, 2, 0, 1, 1, 11, 16, 4, 4, 3, 16, 21, 22, 4, 13, 14, 4, 11, 4, 8, 7, 2, 2, 6, 9, 9, 10, 5, 4, 8, 12, 9, 8, 13, 2, 24, 6, 4, 15, 2, 19, 7]


Train Iter: 008/010  Loss: 2.4327: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.27it/s]


Epoch 8, Loss: 2.7251919647304144, Accuracy: 0.23846821950256614


Valid Iter: 008/010  Loss: 0.6300: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.83it/s]


Validation Loss: 2.5091986367437573, Accuracy: 0.30444444444444446
Class Counts: [4, 1, 10, 6, 22, 10, 6, 7, 7, 0, 1, 0, 2, 20, 11, 9, 5, 8, 7, 16, 8, 7, 12, 8, 11, 15, 7, 18, 7, 8, 3, 6, 7, 10, 14, 16, 10, 7, 11, 8, 10, 12, 2, 20, 9, 11, 13, 6, 11, 11]


Train Iter: 009/010  Loss: 2.3522: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.27it/s]


Epoch 9, Loss: 2.668307285158238, Accuracy: 0.2511962100276352


Valid Iter: 009/010  Loss: 1.0042: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 11.12it/s]


Validation Loss: 2.400941181182861, Accuracy: 0.31333333333333335
Class Counts: [16, 14, 19, 1, 32, 9, 9, 21, 13, 0, 5, 0, 3, 7, 7, 20, 5, 7, 8, 8, 2, 4, 4, 5, 15, 17, 9, 6, 9, 0, 4, 6, 7, 9, 11, 10, 7, 7, 8, 14, 9, 8, 5, 9, 10, 9, 17, 6, 10, 9]


Train Iter: 010/010  Loss: 2.6233: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.26it/s]


Epoch 10, Loss: 2.618602939941551, Accuracy: 0.26356099486774576


Valid Iter: 010/010  Loss: 0.5419: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.89it/s]


Validation Loss: 2.3866857157813177, Accuracy: 0.32
Class Counts: [6, 9, 2, 5, 19, 20, 4, 8, 8, 1, 0, 1, 2, 16, 6, 5, 9, 3, 10, 16, 12, 3, 9, 21, 15, 19, 6, 13, 5, 0, 3, 7, 10, 10, 9, 15, 10, 6, 11, 19, 9, 12, 6, 13, 9, 10, 14, 4, 10, 10]
Finished Training


Test Iter: 010/010  Loss: 0.5419: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.71it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Test Loss: 0.5419027805328369, Accuracy: 0.3511111111111111
Precision: 0.393341902249797, Recall: 0.3511111111111111, F1: 0.3358855826944329
Class Counts: [11, 18, 1, 4, 12, 20, 4, 5, 11, 2, 0, 1, 3, 26, 7, 5, 7, 4, 12, 10, 7, 3, 11, 20, 11, 19, 4, 11, 8, 1, 4, 8, 10, 11, 5, 10, 9, 6, 13, 18, 3, 16, 6, 13, 7, 11, 15, 7, 8, 12]
Finished Testing


Train Iter: 001/010  Loss: 3.0237: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.26it/s]


Epoch 1, Loss: 3.4224574536972256, Accuracy: 0.11902092380576391


Valid Iter: 001/010  Loss: 1.8054: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 11.09it/s]


Validation Loss: 2.9581050973468357, Accuracy: 0.24444444444444444
Class Counts: [7, 12, 20, 32, 10, 6, 9, 0, 2, 1, 0, 0, 2, 5, 12, 16, 0, 4, 14, 2, 12, 1, 12, 5, 6, 4, 0, 6, 13, 18, 8, 11, 10, 9, 13, 13, 11, 10, 3, 12, 10, 14, 7, 10, 19, 16, 20, 3, 12, 8]


Train Iter: 002/010  Loss: 2.8477: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.25it/s]


Epoch 2, Loss: 3.118067568068536, Accuracy: 0.17215949467035135


Valid Iter: 002/010  Loss: 1.7435: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.86it/s]


Validation Loss: 2.814775855806139, Accuracy: 0.24222222222222223
Class Counts: [5, 3, 15, 7, 19, 5, 7, 29, 1, 1, 1, 1, 3, 4, 25, 6, 1, 2, 6, 14, 11, 7, 18, 5, 13, 5, 1, 6, 3, 13, 2, 9, 4, 19, 19, 13, 2, 19, 14, 13, 7, 8, 8, 8, 12, 13, 18, 10, 9, 6]


Train Iter: 003/010  Loss: 3.1728: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.27it/s]


Epoch 3, Loss: 2.9610456970695256, Accuracy: 0.20615870509277537


Valid Iter: 003/010  Loss: 1.0538: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 11.31it/s]


Validation Loss: 2.67418763478597, Accuracy: 0.2733333333333333
Class Counts: [5, 15, 4, 6, 22, 4, 6, 4, 5, 0, 1, 0, 4, 1, 21, 10, 11, 2, 12, 7, 20, 4, 10, 4, 14, 10, 0, 11, 9, 12, 3, 6, 16, 16, 7, 10, 9, 8, 17, 14, 13, 8, 5, 16, 15, 8, 11, 11, 13, 10]


Train Iter: 004/010  Loss: 2.9119: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.27it/s]


Epoch 4, Loss: 2.8301086239790285, Accuracy: 0.23281484405842873


Valid Iter: 004/010  Loss: 1.1765: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 11.05it/s]


Validation Loss: 2.5718461667166816, Accuracy: 0.2733333333333333
Class Counts: [3, 6, 11, 4, 20, 7, 6, 10, 8, 2, 3, 1, 5, 6, 11, 6, 6, 8, 5, 19, 14, 5, 19, 9, 2, 12, 5, 8, 7, 20, 3, 4, 3, 14, 19, 7, 2, 13, 10, 7, 13, 14, 10, 17, 15, 11, 10, 8, 14, 8]


Train Iter: 005/010  Loss: 2.5655: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.26it/s]


Epoch 5, Loss: 2.720182644982984, Accuracy: 0.25716541650217134


Valid Iter: 005/010  Loss: 0.5227: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.57it/s]


Validation Loss: 2.491911267174615, Accuracy: 0.32666666666666666
Class Counts: [4, 6, 11, 10, 16, 7, 8, 12, 4, 2, 7, 0, 3, 11, 23, 11, 10, 6, 5, 18, 9, 11, 13, 6, 3, 16, 6, 10, 4, 5, 0, 11, 9, 9, 9, 4, 9, 9, 10, 15, 12, 7, 9, 13, 13, 9, 11, 3, 17, 14]


Train Iter: 006/010  Loss: 2.5733: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.26it/s]


Epoch 6, Loss: 2.6204818748401832, Accuracy: 0.2821002763521516


Valid Iter: 006/010  Loss: 0.3930: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 11.21it/s]


Validation Loss: 2.4288120946619247, Accuracy: 0.33111111111111113
Class Counts: [8, 6, 5, 7, 15, 4, 6, 8, 10, 3, 3, 1, 3, 5, 14, 11, 7, 4, 12, 20, 6, 8, 6, 6, 7, 10, 6, 10, 19, 5, 4, 9, 10, 8, 9, 11, 5, 6, 13, 10, 19, 11, 7, 22, 13, 9, 13, 5, 17, 14]


Train Iter: 007/010  Loss: 2.3857: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.27it/s]


Epoch 7, Loss: 2.52653309786315, Accuracy: 0.3056296881168575


Valid Iter: 007/010  Loss: 0.3565: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 11.59it/s]


Validation Loss: 2.340601024892595, Accuracy: 0.35555555555555557
Class Counts: [10, 4, 12, 8, 15, 6, 9, 8, 5, 4, 6, 0, 4, 5, 13, 5, 3, 9, 13, 15, 13, 7, 15, 5, 5, 15, 3, 11, 6, 11, 2, 11, 8, 10, 8, 13, 9, 11, 11, 17, 15, 4, 5, 19, 9, 15, 13, 3, 11, 11]


Train Iter: 008/010  Loss: 2.5737: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.27it/s]


Epoch 8, Loss: 2.4395700430954768, Accuracy: 0.32696407422029217


Valid Iter: 008/010  Loss: 0.2454: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.73it/s]


Validation Loss: 2.3364121113883125, Accuracy: 0.3422222222222222
Class Counts: [8, 2, 12, 10, 14, 4, 4, 14, 5, 5, 1, 0, 6, 6, 19, 12, 6, 11, 12, 28, 8, 6, 11, 9, 7, 9, 2, 7, 5, 13, 2, 7, 7, 14, 10, 10, 7, 12, 9, 15, 16, 3, 3, 18, 10, 13, 9, 4, 13, 12]


Train Iter: 009/010  Loss: 2.1157: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.26it/s]


Epoch 9, Loss: 2.3679622554929654, Accuracy: 0.3424555862613502


Valid Iter: 009/010  Loss: 0.2298: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.92it/s]


Validation Loss: 2.273617767095566, Accuracy: 0.36666666666666664
Class Counts: [6, 4, 13, 6, 17, 6, 10, 11, 9, 0, 5, 2, 7, 5, 14, 10, 9, 18, 8, 12, 11, 13, 7, 8, 9, 11, 2, 13, 7, 8, 7, 4, 8, 9, 7, 7, 11, 9, 13, 17, 10, 7, 3, 14, 8, 10, 13, 4, 14, 14]


Train Iter: 010/010  Loss: 1.8568: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.26it/s]


Epoch 10, Loss: 2.2860480521612847, Accuracy: 0.36306356099486775


Valid Iter: 010/010  Loss: 0.2604: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.60it/s]


Validation Loss: 2.1742283548249137, Accuracy: 0.3844444444444444
Class Counts: [9, 5, 16, 6, 16, 9, 8, 12, 7, 4, 6, 2, 5, 10, 10, 8, 6, 7, 9, 22, 9, 12, 12, 5, 7, 9, 6, 11, 9, 5, 6, 9, 8, 6, 9, 6, 11, 6, 11, 10, 16, 9, 13, 10, 8, 13, 9, 5, 12, 11]
Finished Training


Test Iter: 010/010  Loss: 0.2604: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.59it/s]


Test Loss: 0.2604427933692932, Accuracy: 0.38222222222222224
Precision: 0.3783242868242868, Recall: 0.38222222222222224, F1: 0.36899281655662064
Class Counts: [9, 11, 14, 9, 12, 14, 8, 14, 10, 1, 2, 1, 5, 14, 8, 7, 8, 3, 18, 13, 8, 7, 13, 8, 7, 8, 5, 9, 12, 8, 7, 8, 5, 12, 10, 3, 9, 7, 10, 7, 15, 8, 9, 9, 11, 12, 8, 11, 9, 14]
Finished Testing


Train Iter: 001/010  Loss: 3.3494: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.25it/s]


Epoch 1, Loss: 3.5500541147035523, Accuracy: 0.09647058823529411


Valid Iter: 001/010  Loss: 1.6758: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.74it/s]


Validation Loss: 3.188329906993442, Accuracy: 0.1688888888888889
Class Counts: [11, 12, 6, 0, 46, 8, 2, 3, 1, 0, 0, 0, 0, 3, 18, 2, 0, 18, 0, 1, 15, 4, 20, 1, 5, 0, 2, 20, 3, 9, 3, 0, 2, 15, 38, 5, 3, 6, 0, 14, 32, 3, 1, 14, 9, 9, 11, 5, 34, 36]


Train Iter: 002/010  Loss: 3.3871: 100%|[32m██████████[0m| 1979/1979 [02:56<00:00, 11.24it/s]


Epoch 2, Loss: 3.2604653354092075, Accuracy: 0.14649822345045402


Valid Iter: 002/010  Loss: 1.6608: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.75it/s]


Validation Loss: 3.040024666256375, Accuracy: 0.17777777777777778
Class Counts: [6, 7, 26, 2, 16, 13, 8, 8, 1, 0, 7, 0, 0, 3, 3, 7, 5, 7, 7, 13, 6, 2, 6, 4, 3, 9, 0, 12, 5, 15, 1, 2, 0, 35, 39, 17, 1, 12, 4, 8, 13, 16, 20, 3, 11, 7, 7, 4, 29, 20]


Train Iter: 003/010  Loss: 3.1614: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.27it/s]


Epoch 3, Loss: 3.138892748912144, Accuracy: 0.17119621002763522


Valid Iter: 003/010  Loss: 1.8367: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.96it/s]


Validation Loss: 2.94463211218516, Accuracy: 0.2222222222222222
Class Counts: [4, 14, 7, 0, 28, 6, 3, 11, 14, 6, 4, 0, 3, 1, 23, 8, 2, 12, 14, 4, 13, 8, 15, 3, 11, 6, 4, 10, 6, 6, 1, 0, 3, 26, 38, 3, 1, 6, 4, 23, 10, 5, 2, 16, 7, 12, 12, 8, 17, 10]


Train Iter: 004/010  Loss: 2.8282: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.27it/s]


Epoch 4, Loss: 3.0617960622345466, Accuracy: 0.18719305171733122


Valid Iter: 004/010  Loss: 1.1609: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.67it/s]


Validation Loss: 2.893424654006958, Accuracy: 0.22666666666666666
Class Counts: [8, 8, 16, 9, 28, 4, 5, 7, 2, 2, 1, 0, 2, 11, 6, 17, 1, 6, 11, 4, 6, 11, 8, 3, 5, 9, 0, 14, 4, 11, 1, 4, 10, 24, 17, 12, 6, 17, 16, 20, 16, 7, 2, 12, 7, 5, 10, 3, 19, 23]


Train Iter: 005/010  Loss: 2.9424: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.26it/s]


Epoch 5, Loss: 2.998263277440875, Accuracy: 0.20213185945519146


Valid Iter: 005/010  Loss: 1.3177: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 11.12it/s]


Validation Loss: 2.834703960418701, Accuracy: 0.25333333333333335
Class Counts: [4, 18, 13, 6, 18, 9, 11, 5, 2, 7, 0, 0, 3, 16, 2, 15, 10, 5, 6, 14, 11, 6, 17, 4, 4, 8, 1, 11, 7, 13, 6, 0, 10, 25, 13, 9, 6, 8, 9, 19, 8, 9, 5, 11, 10, 11, 8, 4, 16, 17]


Train Iter: 006/010  Loss: 2.7055: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.27it/s]


Epoch 6, Loss: 2.9494135752601487, Accuracy: 0.21184366363995263


Valid Iter: 006/010  Loss: 1.1984: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.96it/s]


Validation Loss: 2.7739552042219375, Accuracy: 0.2733333333333333
Class Counts: [4, 25, 5, 5, 25, 7, 7, 5, 2, 1, 1, 0, 3, 1, 5, 11, 6, 7, 7, 12, 10, 6, 9, 3, 7, 8, 1, 14, 9, 11, 7, 6, 7, 24, 14, 7, 9, 10, 11, 9, 11, 14, 16, 21, 10, 11, 8, 5, 17, 16]


Train Iter: 007/010  Loss: 2.4659: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.25it/s]


Epoch 7, Loss: 2.9041956727569676, Accuracy: 0.223497828661666


Valid Iter: 007/010  Loss: 1.3442: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.74it/s]


Validation Loss: 2.7557923857371014, Accuracy: 0.2688888888888889
Class Counts: [7, 8, 10, 5, 26, 7, 8, 9, 7, 2, 1, 0, 4, 3, 13, 12, 12, 10, 13, 17, 12, 3, 8, 8, 3, 11, 1, 6, 12, 7, 6, 3, 7, 11, 18, 4, 6, 8, 9, 8, 13, 11, 8, 23, 9, 10, 11, 9, 16, 15]


Train Iter: 008/010  Loss: 2.9776: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.26it/s]


Epoch 8, Loss: 2.8648678225318855, Accuracy: 0.2316462692459534


Valid Iter: 008/010  Loss: 0.9917: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 11.49it/s]


Validation Loss: 2.7181041158570185, Accuracy: 0.26222222222222225
Class Counts: [4, 4, 24, 3, 20, 9, 7, 7, 4, 2, 1, 0, 5, 11, 6, 5, 5, 10, 4, 11, 13, 12, 12, 12, 6, 9, 3, 8, 6, 9, 7, 4, 11, 23, 10, 7, 9, 9, 9, 12, 19, 11, 2, 17, 7, 12, 11, 2, 18, 18]


Train Iter: 009/010  Loss: 2.7692: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.26it/s]


Epoch 9, Loss: 2.830897209793976, Accuracy: 0.24045795499407815


Valid Iter: 009/010  Loss: 1.0856: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 11.03it/s]


Validation Loss: 2.712172197235955, Accuracy: 0.28444444444444444
Class Counts: [4, 7, 13, 11, 16, 2, 7, 9, 7, 4, 1, 1, 5, 5, 5, 10, 5, 13, 11, 19, 19, 8, 6, 10, 4, 13, 2, 10, 1, 2, 6, 6, 8, 11, 19, 10, 6, 9, 14, 13, 11, 11, 11, 19, 15, 10, 14, 5, 15, 7]


Train Iter: 010/010  Loss: 2.9024: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.27it/s]


Epoch 10, Loss: 2.794634328654666, Accuracy: 0.24709040663245163


Valid Iter: 010/010  Loss: 0.6913: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.67it/s]


Validation Loss: 2.665474801858266, Accuracy: 0.27555555555555555
Class Counts: [3, 5, 17, 5, 21, 13, 9, 6, 12, 1, 5, 1, 6, 4, 5, 6, 2, 8, 5, 14, 8, 9, 15, 5, 3, 11, 1, 13, 5, 5, 5, 8, 5, 18, 15, 11, 18, 8, 10, 11, 18, 13, 1, 17, 11, 10, 7, 3, 17, 21]
Finished Training


Test Iter: 010/010  Loss: 0.6913: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.91it/s]


Test Loss: 0.6913202404975891, Accuracy: 0.29555555555555557
Precision: 0.3302085315785006, Recall: 0.29555555555555557, F1: 0.28583110524280264
Class Counts: [12, 8, 14, 10, 9, 14, 4, 6, 12, 1, 5, 1, 6, 3, 2, 10, 6, 9, 10, 6, 9, 7, 19, 11, 2, 8, 8, 17, 7, 9, 2, 4, 5, 21, 15, 11, 15, 5, 7, 14, 17, 11, 2, 7, 9, 17, 9, 7, 9, 18]
Finished Testing


In [20]:
with open('images/images.pickle', 'rb') as f:
    img_data = pickle.load(f)
class CustomDataset(Dataset):
    def __init__(self, img_data, txt_file, transform=None):
        self.data = img_data
        with open(txt_file, 'r') as f:
            self.labels = f.readlines()
        self.transform = transform
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, idx):
        img_path, label = self.labels[idx].strip().split()
        img_path = img_path.split('/')
        img = self.data[img_path[1]][img_path[2]]
        img = Image.open(io.BytesIO(img)).convert('RGB')
        if self.transform:
            img = self.transform(img)
        label = torch.tensor(int(label))
        return img, label

transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.1, contrast=0.2, saturation=0.2, hue=0),
    transforms.RandomRotation(degrees=10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
dataset_train = CustomDataset(img_data=img_data, txt_file='images/train.txt', transform=transform_train)
dataset_val = CustomDataset(img_data=img_data, txt_file='images/val.txt', transform=transform)
dataset_test = CustomDataset(img_data=img_data, txt_file='images/test.txt', transform=transform)


data_loader_train = DataLoader(dataset_train, batch_size=32, shuffle=True, num_workers=16)
data_loader_val = DataLoader(dataset_val, batch_size=32, shuffle=False, num_workers=16)
data_loader_test = DataLoader(dataset_test, batch_size=32, shuffle=False, num_workers=16)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        
    def forward(self, x):
        batch_size, C, width, height = x.size()
        
        # 计算查询、键和值
        proj_query = self.query_conv(x).view(batch_size, -1, width * height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(batch_size, -1, width * height)
        energy = torch.bmm(proj_query, proj_key)
        
        # 计算注意力权重并应用到值上
        attention = F.softmax(energy, dim=-1)
        proj_value = self.value_conv(x).view(batch_size, -1, width * height)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        
        # 调整输出并加上残差连接
        out = out.view(batch_size, C, width, height)
        out = self.gamma * out + x
        return out

class CNNModel(nn.Module):
    def __init__(self, num_classes=50):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.attention1 = SelfAttention(256)

        self.fc1 = nn.Linear(256*28*28, 2048)
        self.fc2 = nn.Linear(2048, num_classes)
        self.activation = nn.ReLU()
        self.flatten = nn.Flatten()
        self.pool = nn.MaxPool2d(2, 2)
        # self.dropout25 = nn.Dropout(0.25)
        self.dropout50 = nn.Dropout(0.5)
        
    def forward(self, x):
        x = self.pool(self.activation(self.bn1(self.conv1(x))))
        x = self.pool(self.activation(self.bn2(self.conv2(x))))
        x = self.pool(self.activation(self.bn3(self.conv3(x))))
        x = self.attention1(x)
        x = self.dropout50(self.flatten(x))
        x = self.dropout50(self.activation(self.fc1(x)))
        x = self.fc2(x)
        return x


for lr in [1e-5]:
    model = CNNModel(num_classes=50).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.00025)
    num_epochs = 10
    save_path = 'record/Q2_less_layers'
    rec_name = '6layers_2048_dropout50_adamW_0.00025_augmentless_round2'
    model_path = f'{save_path}/model_weight/model_weight_lr{lr}_{rec_name}.pth'
    best_val_loss = float('inf')
    best_model = None
    train_loss = []
    train_acc = []
    val_loss = []
    val_acc = []
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_acc = 0.0
        pbar = tqdm(data_loader_train, colour='green', total=len(data_loader_train))
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()*len(labels)
            running_acc += (outputs.argmax(dim=1) == labels).float().sum().item()
            pbar.set_description(f'Train Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
            pbar.update()
        running_loss_avg = running_loss / len(dataset_train)
        running_acc_avg = running_acc / len(dataset_train)
        train_loss.append(running_loss_avg)
        train_acc.append(running_acc_avg)
        print(f"Epoch {epoch+1}, Loss: {running_loss_avg}, Accuracy: {running_acc_avg}")

        model.eval()
        with torch.no_grad():
            class_counts = [0] * 50
            running_loss = 0.0
            running_acc = 0.0
            pbar = tqdm(data_loader_val, colour='red', total=len(data_loader_val))
            for images, labels in pbar:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                predictions = outputs.argmax(dim=1)
                loss = criterion(outputs, labels)
                running_loss += loss.item()*len(labels)
                running_acc += (predictions == labels).float().sum().item()
                pbar.set_description(f'Valid Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
                pbar.update()

                for pred_label in predictions:
                    class_counts[pred_label] += 1

            running_loss_avg = running_loss / len(dataset_val)
            running_acc_avg = running_acc / len(dataset_val)
            val_loss.append(running_loss_avg)
            val_acc.append(running_acc_avg)
            print(f"Validation Loss: {running_loss_avg}, Accuracy: {running_acc_avg}")
            print(f"Class Counts: {class_counts}")
            if running_loss_avg < best_val_loss:
                best_val_loss = running_loss
                best_model = model
                torch.save(model.state_dict(), model_path)
        print("======================================")
    print("Finished Training")


    model.eval()
    with torch.no_grad():
        class_counts = [0] * 50
        running_loss = 0.0
        running_acc = 0.0
        predictions = []
        true_labels = []
        pbar = tqdm(data_loader_test, colour='red', total=len(data_loader_test))
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            _, predicted = torch.max(outputs, 1)
            predictions.extend(predicted.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())
            loss = criterion(outputs, labels)

            running_loss += loss.item()*len(labels)
            running_acc += (outputs.argmax(dim=1) == labels).float().sum().item()
            pbar.set_description(f'Test Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
            pbar.update()

        for pred_label in predictions:
            class_counts[pred_label] += 1
        
        precision = precision_score(true_labels, predictions, average='weighted')
        recall = recall_score(true_labels, predictions, average='weighted')
        f1 = f1_score(true_labels, predictions, average='weighted')
        print(f"Test Loss: {running_loss/len(dataset_test)}, Accuracy: {running_acc/len(dataset_test)}")
        print(f"Precision: {precision}, Recall: {recall}, F1: {f1}")
        print(f"Class Counts: {class_counts}")
    print("Finished Testing")

    train_record = {"train_loss": train_loss, "train_acc": train_acc, "val_loss": val_loss, "val_acc": val_acc}
    test_record = {"true_labels": true_labels, "predictions": predictions, "precision": precision, "recall": recall, "f1": f1, "class_counts": class_counts}


    with open(f'{save_path}/train_record/train_record_lr{lr}_{rec_name}.pkl', 'wb') as f:
        pickle.dump(train_record, f)
    with open(f'{save_path}/test_record/test_record_lr{lr}_{rec_name}.pkl', 'wb') as f:
        pickle.dump(test_record, f)


Train Iter: 001/010  Loss: 3.1077: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.27it/s]


Epoch 1, Loss: 3.4416097706803397, Accuracy: 0.11617844453217528


Valid Iter: 001/010  Loss: 2.2201: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 11.09it/s]


Validation Loss: 3.011162149641249, Accuracy: 0.19111111111111112
Class Counts: [8, 6, 22, 11, 28, 2, 2, 6, 0, 3, 0, 5, 2, 7, 7, 13, 0, 11, 13, 1, 21, 2, 18, 3, 2, 6, 1, 1, 25, 3, 8, 3, 5, 23, 21, 3, 4, 11, 7, 21, 8, 7, 5, 17, 28, 6, 12, 3, 18, 11]


Train Iter: 002/010  Loss: 2.9394: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.30it/s]


Epoch 2, Loss: 3.1242212940175405, Accuracy: 0.17146466640347413


Valid Iter: 002/010  Loss: 1.6020: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.93it/s]


Validation Loss: 2.847156081729465, Accuracy: 0.23333333333333334
Class Counts: [8, 15, 19, 7, 29, 7, 1, 9, 1, 5, 1, 1, 1, 12, 8, 18, 0, 6, 5, 22, 15, 8, 6, 3, 1, 6, 2, 11, 5, 18, 4, 5, 9, 19, 13, 1, 6, 10, 14, 17, 6, 9, 2, 15, 10, 8, 13, 3, 21, 15]


Train Iter: 003/010  Loss: 2.7333: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.29it/s]


Epoch 3, Loss: 2.9659936129317375, Accuracy: 0.20251085669166996


Valid Iter: 003/010  Loss: 1.2474: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.66it/s]


Validation Loss: 2.6923208469814726, Accuracy: 0.23555555555555555
Class Counts: [11, 11, 14, 16, 13, 9, 13, 6, 1, 1, 1, 0, 4, 6, 7, 9, 7, 14, 10, 10, 13, 16, 4, 8, 4, 14, 3, 9, 22, 3, 3, 3, 3, 12, 16, 9, 14, 7, 8, 11, 12, 7, 3, 20, 10, 13, 9, 2, 15, 14]


Train Iter: 004/010  Loss: 2.6332: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.29it/s]


Epoch 4, Loss: 2.8303231872554075, Accuracy: 0.23409395973154362


Valid Iter: 004/010  Loss: 0.8462: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.89it/s]


Validation Loss: 2.599440775182512, Accuracy: 0.28444444444444444
Class Counts: [13, 27, 9, 8, 29, 6, 3, 6, 5, 2, 7, 1, 3, 7, 6, 7, 3, 12, 12, 12, 14, 11, 6, 7, 13, 9, 6, 8, 8, 4, 3, 8, 10, 11, 13, 4, 8, 8, 12, 13, 12, 9, 2, 13, 8, 11, 11, 0, 13, 17]


Train Iter: 005/010  Loss: 2.9287: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.29it/s]


Epoch 5, Loss: 2.7166750565245366, Accuracy: 0.2578760363205685


Valid Iter: 005/010  Loss: 0.5497: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.26it/s]


Validation Loss: 2.5050910107294717, Accuracy: 0.30666666666666664
Class Counts: [9, 4, 17, 9, 15, 3, 9, 6, 7, 5, 2, 0, 3, 7, 6, 14, 2, 3, 15, 17, 12, 9, 6, 8, 9, 13, 2, 6, 4, 10, 2, 3, 6, 19, 14, 10, 7, 13, 15, 16, 12, 6, 6, 25, 12, 11, 7, 2, 15, 17]


Train Iter: 006/010  Loss: 2.6706: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.29it/s]


Epoch 6, Loss: 2.614846004854301, Accuracy: 0.2830477694433478


Valid Iter: 006/010  Loss: 0.3358: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.81it/s]


Validation Loss: 2.4175664216942256, Accuracy: 0.35777777777777775
Class Counts: [6, 6, 9, 7, 16, 3, 4, 17, 7, 6, 8, 1, 3, 6, 12, 16, 5, 11, 14, 15, 20, 5, 2, 8, 9, 14, 3, 5, 7, 9, 7, 9, 10, 10, 4, 8, 10, 9, 7, 18, 18, 7, 3, 17, 6, 10, 15, 1, 11, 16]


Train Iter: 007/010  Loss: 2.9512: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.28it/s]


Epoch 7, Loss: 2.527346330673573, Accuracy: 0.301997631267272


Valid Iter: 007/010  Loss: 0.2339: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.73it/s]


Validation Loss: 2.329976656768057, Accuracy: 0.35333333333333333
Class Counts: [4, 7, 11, 8, 17, 6, 7, 8, 10, 2, 5, 1, 4, 9, 8, 10, 5, 5, 14, 18, 9, 5, 16, 8, 5, 16, 6, 15, 7, 3, 8, 16, 9, 7, 5, 7, 7, 11, 11, 15, 13, 10, 5, 17, 7, 11, 13, 1, 15, 13]


Train Iter: 008/010  Loss: 2.6776: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.28it/s]


Epoch 8, Loss: 2.4458274151372477, Accuracy: 0.3212001579155152


Valid Iter: 008/010  Loss: 0.4100: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.96it/s]


Validation Loss: 2.2591083598136903, Accuracy: 0.3711111111111111
Class Counts: [8, 7, 9, 9, 12, 9, 5, 9, 4, 1, 10, 3, 2, 7, 10, 9, 3, 7, 8, 10, 12, 8, 20, 10, 6, 14, 3, 10, 12, 10, 10, 7, 6, 11, 9, 8, 11, 10, 8, 16, 15, 9, 5, 15, 8, 12, 14, 11, 12, 6]


Train Iter: 009/010  Loss: 2.6728: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.29it/s]


Epoch 9, Loss: 2.3662260252940905, Accuracy: 0.3406553493880774


Valid Iter: 009/010  Loss: 0.3478: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.50it/s]


Validation Loss: 2.234450673262278, Accuracy: 0.37777777777777777
Class Counts: [9, 7, 10, 13, 10, 6, 7, 7, 6, 3, 7, 1, 5, 8, 10, 10, 5, 7, 8, 21, 6, 4, 23, 9, 8, 17, 5, 9, 6, 10, 6, 10, 7, 11, 7, 6, 7, 12, 11, 16, 14, 4, 5, 16, 6, 8, 9, 8, 16, 14]


Train Iter: 010/010  Loss: 2.2446: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.29it/s]


Epoch 10, Loss: 2.2926392355578904, Accuracy: 0.35976312672720095


Valid Iter: 010/010  Loss: 0.1628: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.61it/s]


Validation Loss: 2.214882679051823, Accuracy: 0.3844444444444444
Class Counts: [4, 10, 13, 8, 12, 4, 6, 4, 5, 4, 5, 4, 5, 9, 8, 14, 7, 13, 5, 15, 12, 7, 7, 12, 12, 16, 4, 3, 9, 11, 5, 8, 7, 15, 7, 9, 13, 14, 9, 13, 12, 7, 3, 19, 7, 11, 8, 4, 14, 17]
Finished Training


Test Iter: 010/010  Loss: 1.6552: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.71it/s]

Test Loss: 2.238057470851474, Accuracy: 0.38
Precision: 0.3902727827727828, Recall: 0.38, F1: 0.3635531224706753
Class Counts: [8, 11, 8, 5, 10, 15, 3, 7, 4, 2, 4, 4, 7, 14, 5, 10, 6, 9, 18, 14, 10, 3, 9, 15, 9, 10, 5, 7, 2, 12, 4, 6, 10, 21, 9, 12, 4, 11, 9, 10, 13, 13, 3, 18, 10, 15, 10, 9, 8, 9]
Finished Testing





In [21]:
with open('images/images.pickle', 'rb') as f:
    img_data = pickle.load(f)
class CustomDataset(Dataset):
    def __init__(self, img_data, txt_file, transform=None):
        self.data = img_data
        with open(txt_file, 'r') as f:
            self.labels = f.readlines()
        self.transform = transform
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, idx):
        img_path, label = self.labels[idx].strip().split()
        img_path = img_path.split('/')
        img = self.data[img_path[1]][img_path[2]]
        img = Image.open(io.BytesIO(img)).convert('RGB')
        if self.transform:
            img = self.transform(img)
        label = torch.tensor(int(label))
        return img, label

transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.1, contrast=0.2, saturation=0.2, hue=0),
    transforms.RandomRotation(degrees=10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
dataset_train = CustomDataset(img_data=img_data, txt_file='images/train.txt', transform=transform_train)
dataset_val = CustomDataset(img_data=img_data, txt_file='images/val.txt', transform=transform)
dataset_test = CustomDataset(img_data=img_data, txt_file='images/test.txt', transform=transform)


data_loader_train = DataLoader(dataset_train, batch_size=32, shuffle=True, num_workers=16)
data_loader_val = DataLoader(dataset_val, batch_size=32, shuffle=False, num_workers=16)
data_loader_test = DataLoader(dataset_test, batch_size=32, shuffle=False, num_workers=16)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        
    def forward(self, x):
        batch_size, C, width, height = x.size()
        
        # 计算查询、键和值
        proj_query = self.query_conv(x).view(batch_size, -1, width * height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(batch_size, -1, width * height)
        energy = torch.bmm(proj_query, proj_key)
        
        # 计算注意力权重并应用到值上
        attention = F.softmax(energy, dim=-1)
        proj_value = self.value_conv(x).view(batch_size, -1, width * height)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        
        # 调整输出并加上残差连接
        out = out.view(batch_size, C, width, height)
        out = self.gamma * out + x
        return out

class CNNModel(nn.Module):
    def __init__(self, num_classes=50):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.attention1 = SelfAttention(256)

        self.fc1 = nn.Linear(256*28*28, 2048)
        self.fc2 = nn.Linear(2048, num_classes)
        self.activation = nn.ReLU()
        self.flatten = nn.Flatten()
        self.pool = nn.MaxPool2d(2, 2)
        # self.dropout25 = nn.Dropout(0.25)
        self.dropout50 = nn.Dropout(0.5)
        
    def forward(self, x):
        x = self.pool(self.activation(self.bn1(self.conv1(x))))
        x = self.pool(self.activation(self.bn2(self.conv2(x))))
        x = self.pool(self.activation(self.bn3(self.conv3(x))))
        x = self.attention1(x)
        x = self.dropout50(self.flatten(x))
        x = self.dropout50(self.activation(self.fc1(x)))
        x = self.fc2(x)
        return x


for lr in [1e-5]:
    model = CNNModel(num_classes=50).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.00025)
    num_epochs = 10
    save_path = 'record/Q2_less_layers'
    rec_name = '6layers_2048_dropout50_adamW_0.00025_augmentless_round3'
    model_path = f'{save_path}/model_weight/model_weight_lr{lr}_{rec_name}.pth'
    best_val_loss = float('inf')
    best_model = None
    train_loss = []
    train_acc = []
    val_loss = []
    val_acc = []
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_acc = 0.0
        pbar = tqdm(data_loader_train, colour='green', total=len(data_loader_train))
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()*len(labels)
            running_acc += (outputs.argmax(dim=1) == labels).float().sum().item()
            pbar.set_description(f'Train Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
            pbar.update()
        running_loss_avg = running_loss / len(dataset_train)
        running_acc_avg = running_acc / len(dataset_train)
        train_loss.append(running_loss_avg)
        train_acc.append(running_acc_avg)
        print(f"Epoch {epoch+1}, Loss: {running_loss_avg}, Accuracy: {running_acc_avg}")

        model.eval()
        with torch.no_grad():
            class_counts = [0] * 50
            running_loss = 0.0
            running_acc = 0.0
            pbar = tqdm(data_loader_val, colour='red', total=len(data_loader_val))
            for images, labels in pbar:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                predictions = outputs.argmax(dim=1)
                loss = criterion(outputs, labels)
                running_loss += loss.item()*len(labels)
                running_acc += (predictions == labels).float().sum().item()
                pbar.set_description(f'Valid Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
                pbar.update()

                for pred_label in predictions:
                    class_counts[pred_label] += 1

            running_loss_avg = running_loss / len(dataset_val)
            running_acc_avg = running_acc / len(dataset_val)
            val_loss.append(running_loss_avg)
            val_acc.append(running_acc_avg)
            print(f"Validation Loss: {running_loss_avg}, Accuracy: {running_acc_avg}")
            print(f"Class Counts: {class_counts}")
            if running_loss_avg < best_val_loss:
                best_val_loss = running_loss
                best_model = model
                torch.save(model.state_dict(), model_path)
        print("======================================")
    print("Finished Training")


    model.eval()
    with torch.no_grad():
        class_counts = [0] * 50
        running_loss = 0.0
        running_acc = 0.0
        predictions = []
        true_labels = []
        pbar = tqdm(data_loader_test, colour='red', total=len(data_loader_test))
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            _, predicted = torch.max(outputs, 1)
            predictions.extend(predicted.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())
            loss = criterion(outputs, labels)

            running_loss += loss.item()*len(labels)
            running_acc += (outputs.argmax(dim=1) == labels).float().sum().item()
            pbar.set_description(f'Test Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
            pbar.update()

        for pred_label in predictions:
            class_counts[pred_label] += 1
        
        precision = precision_score(true_labels, predictions, average='weighted')
        recall = recall_score(true_labels, predictions, average='weighted')
        f1 = f1_score(true_labels, predictions, average='weighted')
        print(f"Test Loss: {running_loss/len(dataset_test)}, Accuracy: {running_acc/len(dataset_test)}")
        print(f"Precision: {precision}, Recall: {recall}, F1: {f1}")
        print(f"Class Counts: {class_counts}")
    print("Finished Testing")

    train_record = {"train_loss": train_loss, "train_acc": train_acc, "val_loss": val_loss, "val_acc": val_acc}
    test_record = {"true_labels": true_labels, "predictions": predictions, "precision": precision, "recall": recall, "f1": f1, "class_counts": class_counts}


    with open(f'{save_path}/train_record/train_record_lr{lr}_{rec_name}.pkl', 'wb') as f:
        pickle.dump(train_record, f)
    with open(f'{save_path}/test_record/test_record_lr{lr}_{rec_name}.pkl', 'wb') as f:
        pickle.dump(test_record, f)


Train Iter: 001/010  Loss: 3.0541: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.27it/s]


Epoch 1, Loss: 3.427572805647129, Accuracy: 0.11941571259376234


Valid Iter: 001/010  Loss: 1.8616: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.70it/s]


Validation Loss: 2.992826567225986, Accuracy: 0.2111111111111111
Class Counts: [5, 17, 13, 8, 42, 9, 1, 4, 1, 0, 0, 0, 0, 2, 3, 9, 5, 3, 1, 0, 16, 14, 19, 7, 1, 13, 6, 13, 3, 7, 2, 7, 17, 34, 15, 5, 3, 4, 10, 23, 6, 10, 7, 14, 15, 8, 12, 5, 15, 16]


Train Iter: 002/010  Loss: 3.3663: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.26it/s]


Epoch 2, Loss: 3.108815958536617, Accuracy: 0.17313857086458745


Valid Iter: 002/010  Loss: 1.7920: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.88it/s]


Validation Loss: 2.776230977376302, Accuracy: 0.23333333333333334
Class Counts: [7, 6, 23, 10, 33, 5, 7, 2, 6, 2, 1, 1, 3, 2, 6, 1, 1, 8, 18, 6, 6, 1, 9, 10, 10, 4, 1, 2, 17, 11, 6, 5, 10, 25, 5, 9, 24, 14, 13, 16, 12, 8, 5, 16, 11, 5, 15, 3, 19, 10]


Train Iter: 003/010  Loss: 2.6661: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.25it/s]


Epoch 3, Loss: 2.951364253850215, Accuracy: 0.20628503750493485


Valid Iter: 003/010  Loss: 1.2922: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 11.37it/s]


Validation Loss: 2.6724882851706613, Accuracy: 0.24666666666666667
Class Counts: [6, 6, 12, 9, 19, 2, 9, 5, 5, 0, 1, 1, 4, 3, 14, 9, 3, 9, 7, 13, 8, 19, 3, 9, 4, 13, 1, 14, 5, 8, 3, 8, 12, 24, 9, 7, 6, 12, 9, 9, 16, 18, 8, 24, 9, 10, 10, 5, 17, 13]


Train Iter: 004/010  Loss: 2.9903: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.27it/s]


Epoch 4, Loss: 2.8246917091637695, Accuracy: 0.23414133438610343


Valid Iter: 004/010  Loss: 1.0298: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.86it/s]


Validation Loss: 2.5848861763212416, Accuracy: 0.29333333333333333
Class Counts: [7, 6, 14, 3, 18, 5, 6, 11, 4, 3, 0, 2, 4, 1, 8, 21, 7, 4, 13, 21, 6, 21, 0, 12, 6, 15, 1, 7, 8, 7, 4, 18, 10, 16, 7, 5, 10, 5, 12, 16, 8, 4, 3, 24, 11, 13, 10, 11, 17, 5]


Train Iter: 005/010  Loss: 2.6109: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.28it/s]


Epoch 5, Loss: 2.708543711875768, Accuracy: 0.26146071851559416


Valid Iter: 005/010  Loss: 0.8660: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.99it/s]


Validation Loss: 2.4969357665379843, Accuracy: 0.3
Class Counts: [7, 6, 6, 8, 19, 9, 6, 13, 10, 3, 1, 1, 5, 11, 8, 10, 10, 6, 15, 23, 9, 5, 8, 6, 2, 15, 1, 16, 15, 5, 4, 5, 7, 15, 9, 10, 10, 7, 14, 11, 15, 9, 2, 11, 8, 11, 12, 8, 12, 11]


Train Iter: 006/010  Loss: 2.3303: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.27it/s]


Epoch 6, Loss: 2.6072341236915695, Accuracy: 0.2848322147651007


Valid Iter: 006/010  Loss: 0.5218: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 11.22it/s]


Validation Loss: 2.4437450308269923, Accuracy: 0.31555555555555553
Class Counts: [5, 2, 8, 7, 22, 6, 7, 10, 9, 3, 3, 2, 3, 11, 7, 13, 7, 7, 7, 26, 5, 9, 10, 8, 4, 10, 2, 16, 10, 7, 1, 4, 14, 28, 5, 5, 5, 11, 13, 15, 15, 4, 2, 21, 9, 12, 10, 7, 10, 13]


Train Iter: 007/010  Loss: 2.1744: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.27it/s]


Epoch 7, Loss: 2.5262726030940836, Accuracy: 0.30283458349782866


Valid Iter: 007/010  Loss: 0.6584: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.91it/s]


Validation Loss: 2.349793258772956, Accuracy: 0.3511111111111111
Class Counts: [6, 3, 10, 7, 21, 10, 5, 8, 8, 2, 0, 3, 3, 3, 11, 11, 4, 10, 7, 20, 18, 5, 9, 10, 2, 17, 1, 15, 11, 7, 11, 10, 8, 12, 6, 12, 8, 13, 9, 12, 15, 10, 5, 10, 10, 10, 12, 7, 12, 11]


Train Iter: 008/010  Loss: 2.1341: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.26it/s]


Epoch 8, Loss: 2.43541670050397, Accuracy: 0.3253217528622187


Valid Iter: 008/010  Loss: 0.3473: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.77it/s]


Validation Loss: 2.335539667341444, Accuracy: 0.3466666666666667
Class Counts: [5, 3, 6, 9, 10, 11, 5, 9, 6, 3, 5, 4, 4, 7, 12, 8, 3, 10, 8, 18, 13, 7, 16, 9, 9, 14, 3, 13, 8, 5, 4, 5, 10, 12, 9, 9, 12, 18, 11, 15, 13, 5, 4, 16, 11, 10, 9, 6, 14, 14]


Train Iter: 009/010  Loss: 2.4312: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.26it/s]


Epoch 9, Loss: 2.361083850209251, Accuracy: 0.3446821950256613


Valid Iter: 009/010  Loss: 0.2543: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.65it/s]


Validation Loss: 2.2743116823832192, Accuracy: 0.35333333333333333
Class Counts: [6, 8, 12, 8, 13, 5, 8, 11, 6, 3, 11, 1, 7, 7, 11, 7, 5, 14, 10, 9, 18, 4, 11, 6, 10, 17, 2, 14, 8, 5, 5, 10, 5, 9, 10, 6, 11, 10, 9, 12, 14, 8, 5, 13, 9, 11, 8, 9, 14, 15]


Train Iter: 010/010  Loss: 2.3193: 100%|[32m██████████[0m| 1979/1979 [02:55<00:00, 11.26it/s]


Epoch 10, Loss: 2.2907818063577077, Accuracy: 0.3630161863403079


Valid Iter: 010/010  Loss: 0.2965: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 11.02it/s]


Validation Loss: 2.2528827528158826, Accuracy: 0.36666666666666664
Class Counts: [5, 6, 11, 10, 10, 10, 10, 9, 8, 2, 6, 2, 4, 8, 6, 3, 4, 10, 16, 24, 13, 10, 8, 7, 12, 14, 1, 13, 10, 4, 4, 8, 7, 10, 9, 10, 11, 11, 11, 11, 15, 9, 4, 12, 9, 8, 10, 12, 12, 11]
Finished Training


Test Iter: 010/010  Loss: 0.7742: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 10.92it/s]


Test Loss: 2.217808720800612, Accuracy: 0.38666666666666666
Precision: 0.39547956945015766, Recall: 0.38666666666666666, F1: 0.3733428399327847
Class Counts: [6, 5, 14, 12, 6, 12, 11, 12, 6, 2, 5, 2, 6, 16, 12, 4, 5, 9, 18, 11, 11, 7, 11, 9, 9, 11, 5, 6, 11, 8, 2, 5, 10, 17, 13, 9, 4, 8, 11, 9, 7, 8, 8, 11, 10, 13, 14, 13, 8, 8]
Finished Testing


In [4]:
with open('images/images.pickle', 'rb') as f:
    img_data = pickle.load(f)
class CustomDataset(Dataset):
    def __init__(self, img_data, txt_file, transform=None):
        self.data = img_data
        with open(txt_file, 'r') as f:
            self.labels = f.readlines()
        self.transform = transform
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, idx):
        img_path, label = self.labels[idx].strip().split()
        img_path = img_path.split('/')
        img = self.data[img_path[1]][img_path[2]]
        img = Image.open(io.BytesIO(img)).convert('RGB')
        if self.transform:
            img = self.transform(img)
        label = torch.tensor(int(label))
        return img, label

transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.1, contrast=0.2, saturation=0.2, hue=0),
    transforms.RandomRotation(degrees=10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
dataset_train = CustomDataset(img_data=img_data, txt_file='images/train.txt', transform=transform_train)
dataset_val = CustomDataset(img_data=img_data, txt_file='images/val.txt', transform=transform)
dataset_test = CustomDataset(img_data=img_data, txt_file='images/test.txt', transform=transform)


data_loader_train = DataLoader(dataset_train, batch_size=32, shuffle=True, num_workers=16)
data_loader_val = DataLoader(dataset_val, batch_size=32, shuffle=False, num_workers=16)
data_loader_test = DataLoader(dataset_test, batch_size=32, shuffle=False, num_workers=16)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


class CNNModel(nn.Module):
    def __init__(self, num_classes=50):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        # self.attention1 = SelfAttention(256)

        self.fc1 = nn.Linear(256*28*28, 2048)
        self.fc2 = nn.Linear(2048, num_classes)
        self.activation = nn.ReLU()
        self.flatten = nn.Flatten()
        self.pool = nn.MaxPool2d(2, 2)
        # self.dropout25 = nn.Dropout(0.25)
        self.dropout50 = nn.Dropout(0.5)
        
    def forward(self, x):
        x = self.pool(self.activation(self.bn1(self.conv1(x))))
        x = self.pool(self.activation(self.bn2(self.conv2(x))))
        x = self.pool(self.activation(self.bn3(self.conv3(x))))
        x = self.dropout50(self.flatten(x))
        x = self.dropout50(self.activation(self.fc1(x)))
        x = self.fc2(x)
        return x


for lr in [1e-5]:
    model = CNNModel(num_classes=50).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.00025)
    num_epochs = 10
    save_path = 'record/Q2_less_layers'
    rec_name = '5layers_2048_dropout50_adamW_0.00025_augmentless_noattention'
    model_path = f'{save_path}/model_weight/model_weight_lr{lr}_{rec_name}.pth'
    best_val_loss = float('inf')
    best_model = None
    train_loss = []
    train_acc = []
    val_loss = []
    val_acc = []
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_acc = 0.0
        pbar = tqdm(data_loader_train, colour='green', total=len(data_loader_train))
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()*len(labels)
            running_acc += (outputs.argmax(dim=1) == labels).float().sum().item()
            pbar.set_description(f'Train Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
            pbar.update()
        running_loss_avg = running_loss / len(dataset_train)
        running_acc_avg = running_acc / len(dataset_train)
        train_loss.append(running_loss_avg)
        train_acc.append(running_acc_avg)
        print(f"Epoch {epoch+1}, Loss: {running_loss_avg}, Accuracy: {running_acc_avg}")

        model.eval()
        with torch.no_grad():
            class_counts = [0] * 50
            running_loss = 0.0
            running_acc = 0.0
            pbar = tqdm(data_loader_val, colour='red', total=len(data_loader_val))
            for images, labels in pbar:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                predictions = outputs.argmax(dim=1)
                loss = criterion(outputs, labels)
                running_loss += loss.item()*len(labels)
                running_acc += (predictions == labels).float().sum().item()
                pbar.set_description(f'Valid Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
                pbar.update()

                for pred_label in predictions:
                    class_counts[pred_label] += 1

            running_loss_avg = running_loss / len(dataset_val)
            running_acc_avg = running_acc / len(dataset_val)
            val_loss.append(running_loss_avg)
            val_acc.append(running_acc_avg)
            print(f"Validation Loss: {running_loss_avg}, Accuracy: {running_acc_avg}")
            print(f"Class Counts: {class_counts}")
            if running_loss_avg < best_val_loss:
                best_val_loss = running_loss
                best_model = model
                torch.save(model.state_dict(), model_path)
        print("======================================")
    print("Finished Training")


    model.eval()
    with torch.no_grad():
        class_counts = [0] * 50
        running_loss = 0.0
        running_acc = 0.0
        predictions = []
        true_labels = []
        pbar = tqdm(data_loader_test, colour='red', total=len(data_loader_test))
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            _, predicted = torch.max(outputs, 1)
            predictions.extend(predicted.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())
            loss = criterion(outputs, labels)

            running_loss += loss.item()*len(labels)
            running_acc += (outputs.argmax(dim=1) == labels).float().sum().item()
            pbar.set_description(f'Test Iter: {epoch+1:03}/{num_epochs:03}  Loss: {loss:3.4f}')
            pbar.update()

        for pred_label in predictions:
            class_counts[pred_label] += 1
        
        precision = precision_score(true_labels, predictions, average='weighted')
        recall = recall_score(true_labels, predictions, average='weighted')
        f1 = f1_score(true_labels, predictions, average='weighted')
        print(f"Test Loss: {running_loss/len(dataset_test)}, Accuracy: {running_acc/len(dataset_test)}")
        print(f"Precision: {precision}, Recall: {recall}, F1: {f1}")
        print(f"Class Counts: {class_counts}")
    print("Finished Testing")

    train_record = {"train_loss": train_loss, "train_acc": train_acc, "val_loss": val_loss, "val_acc": val_acc}
    test_record = {"true_labels": true_labels, "predictions": predictions, "precision": precision, "recall": recall, "f1": f1, "class_counts": class_counts}


    with open(f'{save_path}/train_record/train_record_lr{lr}_{rec_name}.pkl', 'wb') as f:
        pickle.dump(train_record, f)
    with open(f'{save_path}/test_record/test_record_lr{lr}_{rec_name}.pkl', 'wb') as f:
        pickle.dump(test_record, f)


Train Iter: 001/010  Loss: 3.6908: 100%|[32m██████████[0m| 1979/1979 [02:47<00:00, 11.81it/s]


Epoch 1, Loss: 3.439309311820451, Accuracy: 0.115499407816818


Valid Iter: 001/010  Loss: 2.6167: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 13.41it/s]


Validation Loss: 3.0437215614318847, Accuracy: 0.18444444444444444
Class Counts: [10, 0, 35, 4, 49, 15, 4, 7, 0, 0, 0, 0, 0, 8, 5, 11, 0, 1, 1, 2, 11, 10, 10, 6, 6, 10, 0, 3, 4, 4, 14, 2, 9, 29, 31, 7, 2, 16, 5, 18, 10, 6, 5, 23, 11, 5, 15, 3, 18, 5]


Train Iter: 002/010  Loss: 2.9658: 100%|[32m██████████[0m| 1979/1979 [02:47<00:00, 11.78it/s]


Epoch 2, Loss: 3.1501422381749324, Accuracy: 0.1650217133833399


Valid Iter: 002/010  Loss: 1.6219: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 12.71it/s]


Validation Loss: 2.8857523022757636, Accuracy: 0.21333333333333335
Class Counts: [5, 24, 3, 17, 25, 12, 6, 5, 3, 3, 0, 0, 1, 8, 2, 4, 9, 2, 4, 13, 7, 10, 8, 3, 2, 9, 4, 17, 3, 6, 6, 0, 0, 22, 28, 12, 3, 20, 6, 7, 9, 19, 16, 12, 13, 7, 15, 6, 17, 17]


Train Iter: 003/010  Loss: 2.7705: 100%|[32m██████████[0m| 1979/1979 [02:48<00:00, 11.75it/s]


Epoch 3, Loss: 3.024493179287404, Accuracy: 0.19174101855507303


Valid Iter: 003/010  Loss: 1.3442: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 13.11it/s]


Validation Loss: 2.7805364343855117, Accuracy: 0.2311111111111111
Class Counts: [12, 10, 9, 33, 22, 7, 5, 4, 3, 5, 0, 1, 3, 1, 13, 14, 7, 1, 7, 10, 8, 15, 5, 1, 5, 2, 5, 2, 17, 6, 4, 5, 7, 18, 20, 6, 5, 13, 14, 12, 9, 7, 10, 19, 18, 8, 11, 4, 20, 7]


Train Iter: 004/010  Loss: 2.6570: 100%|[32m██████████[0m| 1979/1979 [02:48<00:00, 11.75it/s]


Epoch 4, Loss: 2.920278646470622, Accuracy: 0.2126016581129096


Valid Iter: 004/010  Loss: 1.0967: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 12.84it/s]


Validation Loss: 2.696459993256463, Accuracy: 0.24
Class Counts: [11, 13, 6, 3, 30, 9, 6, 7, 5, 1, 4, 0, 2, 5, 5, 7, 9, 3, 4, 5, 15, 15, 19, 16, 1, 8, 3, 5, 3, 12, 5, 5, 5, 11, 14, 10, 12, 9, 19, 9, 13, 14, 4, 23, 6, 9, 23, 2, 17, 8]


Train Iter: 005/010  Loss: 3.2118: 100%|[32m██████████[0m| 1979/1979 [02:48<00:00, 11.74it/s]


Epoch 5, Loss: 2.837686190648461, Accuracy: 0.22896170548756414


Valid Iter: 005/010  Loss: 0.9505: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 13.65it/s]


Validation Loss: 2.594658277299669, Accuracy: 0.2511111111111111
Class Counts: [6, 2, 6, 5, 29, 16, 8, 3, 14, 4, 2, 1, 5, 5, 16, 17, 5, 14, 2, 16, 10, 13, 7, 5, 6, 8, 0, 9, 12, 4, 6, 8, 6, 11, 12, 4, 14, 13, 12, 11, 16, 9, 6, 16, 11, 8, 6, 2, 16, 13]


Train Iter: 006/010  Loss: 2.4597: 100%|[32m██████████[0m| 1979/1979 [02:48<00:00, 11.74it/s]


Epoch 6, Loss: 2.7595282038783915, Accuracy: 0.24546387682589815


Valid Iter: 006/010  Loss: 0.9308: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 12.76it/s]


Validation Loss: 2.5452331007851496, Accuracy: 0.28888888888888886
Class Counts: [6, 4, 9, 7, 19, 11, 17, 7, 12, 2, 4, 0, 4, 13, 14, 9, 3, 1, 4, 28, 13, 8, 13, 9, 5, 7, 3, 7, 9, 11, 6, 0, 6, 15, 10, 9, 11, 12, 10, 7, 13, 15, 1, 19, 11, 9, 13, 2, 17, 5]


Train Iter: 007/010  Loss: 2.9916: 100%|[32m██████████[0m| 1979/1979 [02:48<00:00, 11.75it/s]


Epoch 7, Loss: 2.690365862283507, Accuracy: 0.2643347808922227


Valid Iter: 007/010  Loss: 0.9856: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 12.54it/s]


Validation Loss: 2.508874473041958, Accuracy: 0.2866666666666667
Class Counts: [7, 7, 11, 9, 21, 10, 4, 6, 11, 2, 2, 0, 3, 7, 10, 6, 5, 3, 16, 15, 7, 6, 8, 11, 4, 5, 1, 11, 12, 13, 3, 8, 12, 11, 6, 11, 9, 6, 10, 10, 12, 20, 29, 12, 9, 12, 12, 2, 15, 8]


Train Iter: 008/010  Loss: 2.2139: 100%|[32m██████████[0m| 1979/1979 [02:48<00:00, 11.76it/s]


Epoch 8, Loss: 2.6268260446820295, Accuracy: 0.2747256217923411


Valid Iter: 008/010  Loss: 0.5007: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 12.70it/s]


Validation Loss: 2.4560040232870315, Accuracy: 0.31333333333333335
Class Counts: [12, 11, 7, 6, 19, 8, 7, 4, 15, 2, 4, 0, 3, 8, 15, 8, 5, 7, 8, 23, 9, 17, 10, 3, 7, 8, 8, 7, 4, 6, 3, 5, 11, 9, 5, 11, 7, 11, 10, 10, 19, 11, 3, 25, 13, 8, 10, 5, 16, 7]


Train Iter: 009/010  Loss: 2.1269: 100%|[32m██████████[0m| 1979/1979 [02:48<00:00, 11.75it/s]


Epoch 9, Loss: 2.576574374717354, Accuracy: 0.291006711409396


Valid Iter: 009/010  Loss: 0.6516: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 12.77it/s]


Validation Loss: 2.4316266952620613, Accuracy: 0.32
Class Counts: [5, 9, 6, 12, 10, 8, 5, 12, 14, 2, 2, 1, 4, 8, 16, 7, 9, 8, 6, 8, 12, 13, 7, 10, 9, 13, 4, 6, 10, 8, 4, 8, 9, 10, 6, 16, 8, 12, 10, 13, 15, 12, 0, 27, 12, 12, 9, 1, 14, 8]


Train Iter: 010/010  Loss: 2.4288: 100%|[32m██████████[0m| 1979/1979 [02:48<00:00, 11.77it/s]


Epoch 10, Loss: 2.52874028566935, Accuracy: 0.30185550730359256


Valid Iter: 010/010  Loss: 0.4356: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 12.63it/s]


Validation Loss: 2.4062459495332504, Accuracy: 0.31777777777777777
Class Counts: [9, 9, 9, 7, 16, 8, 8, 7, 7, 2, 1, 0, 4, 14, 24, 5, 7, 10, 4, 18, 6, 7, 17, 4, 10, 10, 5, 12, 7, 7, 3, 2, 10, 8, 7, 14, 8, 12, 15, 20, 12, 2, 13, 12, 5, 9, 8, 5, 14, 17]
Finished Training


Test Iter: 010/010  Loss: 1.1154: 100%|[31m██████████[0m| 15/15 [00:01<00:00, 12.51it/s]


Test Loss: 2.359922768274943, Accuracy: 0.35555555555555557
Precision: 0.38631091122499794, Recall: 0.35555555555555557, F1: 0.3412440434481002
Class Counts: [11, 8, 8, 9, 1, 17, 7, 12, 8, 2, 3, 6, 4, 13, 20, 12, 3, 9, 8, 16, 7, 3, 13, 6, 5, 11, 6, 12, 5, 5, 3, 2, 8, 15, 12, 8, 6, 5, 19, 12, 15, 6, 15, 9, 9, 13, 9, 12, 6, 16]
Finished Testing
