In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import shutil

src = "/content/drive/MyDrive/mute&cut_cqt.zip"  # 原文件路径
dst = "/content/mute&cut_cqt.zip"  # 目标路径

shutil.copy(src, dst)  # 复制文件（不包含元数据）
# shutil.copy2(src, dst)  # 复制文件（包含元数据）

'/content/mute&cut_cqt.zip'

In [3]:
import shutil

src = "/content/drive/MyDrive/mute&cut_rgba.zip"  # 原文件路径
dst = "/content/mute&cut_rgba.zip"  # 目标路径

shutil.copy(src, dst)  # 复制文件（不包含元数据）
# shutil.copy2(src, dst)  # 复制文件（包含元数据）

'/content/mute&cut_rgba.zip'

In [4]:
import zipfile

zip_path = "/content/mute&cut_rgba.zip"       # 待解压的 ZIP 文件路径
extract_to = "/content"   # 解压后存放文件的目标文件夹

# 打开 ZIP 文件并解压
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_to)

In [5]:
import zipfile

zip_path = "/content/mute&cut_cqt.zip"       # 待解压的 ZIP 文件路径
extract_to = "/content"   # 解压后存放文件的目标文件夹

# 打开 ZIP 文件并解压
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_to)

In [6]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.metrics import precision_score, recall_score, accuracy_score
import torchvision.transforms as transforms

In [70]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

class GTCCMultiModalDataset(Dataset):
    def __init__(self, rgba_dir, cqt_dir, split='train', transform_resnet=None, transform_efficientnet=None):
        """
        自定义数据集类，确保 RGBA (4 通道) 和 CQT (1 通道) 数据的文件名和标签匹配。

        参数:
            rgba_dir (str): RGBA 图像数据集根目录，例如 "full_mixed_rgba"。
            cqt_dir (str): CQT 频谱图像数据集根目录，例如 "full_mixed_cqt"。
            split (str): 'train' 或 'test' 指定加载的子目录。
            transform_resnet: 适用于 RGBA 的图像转换 (ToTensor, Normalize)。
            transform_efficientnet: 适用于 CQT 的图像转换 (ToTensor, Normalize)。
        """
        self.transform_resnet = transform_resnet
        self.transform_efficientnet = transform_efficientnet

        rgba_split_dir = os.path.join(rgba_dir, split)
        cqt_split_dir = os.path.join(cqt_dir, split)

        # 获取所有类别
        self.classes = sorted([d for d in os.listdir(rgba_split_dir) if os.path.isdir(os.path.join(rgba_split_dir, d))])
        self.class_to_idx = {class_name: idx for idx, class_name in enumerate(self.classes)}

        self.data_pairs = []
        # 遍历 RGBA 数据集，并确保匹配 CQT 频谱数据
        for class_name in self.classes:
            rgba_class_dir = os.path.join(rgba_split_dir, class_name)
            cqt_class_dir = os.path.join(cqt_split_dir, class_name)

            for root, _, files in os.walk(rgba_class_dir):
                for fname in files:
                    if fname.lower().endswith(('.png', '.jpg', '.jpeg')):
                        rgba_img_path = os.path.join(root, fname)
                        cqt_img_path = os.path.join(cqt_class_dir, os.path.relpath(rgba_img_path, rgba_class_dir))  # 假设文件名相同

                        if os.path.exists(cqt_img_path):  # 确保 CQT 文件存在
                            label = self.class_to_idx[class_name]
                            self.data_pairs.append((rgba_img_path, cqt_img_path, label))
                        else:
                            print(f"CQT 图像不存在: {cqt_img_path}")
                    else:
                        print(f"RGBA 目录中发现无效图像文件: {os.path.join(root, fname)}")

        print(f"总共找到的数据对: {len(self.data_pairs)}")

    def __len__(self):
        return len(self.data_pairs)

    def __getitem__(self, idx):
        rgba_path, cqt_path, label = self.data_pairs[idx]

        # 加载 RGBA 图像
        rgba_image = Image.open(rgba_path).convert('RGBA')  # 确保 4 通道
        if self.transform_resnet:
            rgba_image = self.transform_resnet(rgba_image)

        # 加载 CQT 频谱图像
        cqt_image = Image.open(cqt_path).convert('L')  # 确保 1 通道
        if self.transform_efficientnet:
            cqt_image = self.transform_efficientnet(cqt_image)

        # 处理标签
        label = torch.tensor(label, dtype=torch.long)

        return rgba_image, cqt_image, label


# 示例使用
if __name__ == '__main__':
    from torchvision import transforms

    # 设置数据集路径
    rgba_dir = "/content/mute&cut_rgba"
    cqt_dir = "/content/mute&cut_cqt"

    # RGBA 图像的 transforms
    transform_resnet = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5, 0.5), (0.5, 0.5, 0.5, 0.5))  # RGBA 4 通道
    ])

    # CQT 频谱图像的 transforms
    transform_efficientnet = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))  # CQT 1 通道
    ])

    # 创建训练集
    train_dataset = GTCCMultiModalDataset(rgba_dir, cqt_dir, split='train',
                                          transform_resnet=transform_resnet,
                                          transform_efficientnet=transform_efficientnet)

    # 创建测试集
    test_dataset = GTCCMultiModalDataset(rgba_dir, cqt_dir, split='test',
                                         transform_resnet=transform_resnet,
                                         transform_efficientnet=transform_efficientnet)

    if len(train_dataset) == 0 or len(test_dataset) == 0:
        print("未找到数据。请检查您的数据集目录和文件名。")
    else:
        # 创建 DataLoader
        batch_size = 128
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

        print(f"训练集大小: {len(train_dataset)}")
        print(f"测试集大小: {len(test_dataset)}")

总共找到的数据对: 11326
总共找到的数据对: 2721
训练集大小: 11326
测试集大小: 2721


In [71]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import glob

class GTCCMultiModalDataset(Dataset):
    def __init__(self, rgba_dir, cqt_dir, split='train',
                 transform_resnet=None, transform_efficientnet=None):
        """
        自定义数据集类，支持在更深层级(三级目录及以下)查找图像
        """
        self.transform_resnet = transform_resnet
        self.transform_efficientnet = transform_efficientnet

        rgba_split_dir = os.path.join(rgba_dir, split)
        cqt_split_dir = os.path.join(cqt_dir, split)

        # 首先获取"类别" (依然假设类别就是二级目录名字)
        self.classes = sorted([
            d for d in os.listdir(rgba_split_dir)
            if os.path.isdir(os.path.join(rgba_split_dir, d))
        ])
        self.class_to_idx = {class_name: idx for idx, class_name in enumerate(self.classes)}

        # 开始递归地遍历
        self.data_pairs = []
        for class_name in self.classes:
            rgba_class_dir = os.path.join(rgba_split_dir, class_name)
            cqt_class_dir = os.path.join(cqt_split_dir, class_name)

            # **使用glob进行递归搜索**，这里示例只搜 '.png', '.jpg', '.jpeg'
            patterns = ['*.png', '*.jpg', '*.jpeg']
            rgba_file_list = []
            for p in patterns:
                rgba_file_list.extend(glob.glob(os.path.join(rgba_class_dir, '**', p), recursive=True))

            # 开始匹配这些RGBA文件对应的CQT文件
            for rgba_img_path in rgba_file_list:
                # 拿到最末级文件名
                fname = os.path.basename(rgba_img_path)
                # 构造对应的 CQT 文件路径 (假设同名)
                cqt_img_path = os.path.join(cqt_class_dir, fname)

                if os.path.exists(cqt_img_path):
                    label = self.class_to_idx[class_name]
                    self.data_pairs.append((rgba_img_path, cqt_img_path, label))

    def __len__(self):
        return len(self.data_pairs)

    def __getitem__(self, idx):
        rgba_path, cqt_path, label = self.data_pairs[idx]

        # 加载 RGBA
        rgba_image = Image.open(rgba_path).convert('RGBA')
        if self.transform_resnet:
            rgba_image = self.transform_resnet(rgba_image)

        # 加载 CQT
        cqt_image = Image.open(cqt_path).convert('L')
        if self.transform_efficientnet:
            cqt_image = self.transform_efficientnet(cqt_image)

        label = torch.tensor(label, dtype=torch.long)
        return rgba_image, cqt_image, label


In [72]:
import torch.nn.functional as F

class ChannelAttention(nn.Module):
    def __init__(self, in_channels, reduction_ratio=4):
        super(ChannelAttention, self).__init__()

        # 평균 풀링과 최대 풀링을 사용하여 채널 중요도 계산
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # 채널 수를 줄였다가 다시 확장하는 두 개의 1x1 컨볼루션을 사용
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // reduction_ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_channels // reduction_ratio, in_channels, 1, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        # 평균 풀링과 최대 풀링을 통해 각 채널에 대한 중요도를 계산
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))

        # 각 채널의 중요도를 합산하고 입력에 가중치를 적용
        out = avg_out + max_out
        out = out * x
        return F.relu(out)  # ReLU 활성화 함수 적용


In [73]:
class Stage1(nn.Module):
    def __init__(self):
        super(Stage1, self).__init__()

        # 첫 번째 컨볼루션 + 배치 정규화 + GELU 활성화 함수
        self.conv1 = nn.Sequential(
            nn.Conv2d(4, 64, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.GELU()
        )

        # 채널 어텐션
        self.cam1 = ChannelAttention(64)
        # 맥스풀링
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        # 첫 번째 컨볼루션 및 채널 어텐션 적용
        out = self.conv1(x)
        out = self.cam1(out)
        out = self.maxpool(out)
        return out

In [74]:
class AConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(AConvBlock, self).__init__()

        # 메인 브랜치: 여러 개의 7x7 컨볼루션과 배치 정규화
        self.main_branch = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=stride, padding=3),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),

            nn.Conv2d(out_channels, out_channels, kernel_size=7, stride=1, padding=3),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),

            nn.Conv2d(out_channels, out_channels * 8, kernel_size=7, stride=1, padding=3),
            nn.BatchNorm2d(out_channels * 8)
        )

        # Shortcut: 입력과 출력의 차원이 다르면 1x1 컨볼루션을 사용해 맞춰줌
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels * 8:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * 8, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * 8)
            )

        # 채널 어텐션
        self.cam = ChannelAttention(out_channels * 8)

    def forward(self, x):
        # 메인 브랜치를 통과한 결과
        out = self.main_branch(x)
        # shortcut 연결 추가
        out += self.cam(self.shortcut(x))
        return F.relu(out)  # ReLU 활성화 함수 적용

In [75]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResNetFeatureExtractor(nn.Module):
    def __init__(self):
        super(ResNetFeatureExtractor, self).__init__()

        # 定义网络的各个阶段
        self.conv1_x = Stage1()  # 第一个卷积层和池化层
        self.conv2_x = AConvBlock(64, 32)  # 输入通道64，输出通道32
        self.conv3_x = AConvBlock(32, 64)  # 输入通道32，输出通道64
        self.conv4_x = AConvBlock(64, 128) # 输入通道64，输出通道128
        self.conv5_x = AConvBlock(128, 256) # 输入通道128，输出通道256
        self.conv6_x = AConvBlock(256, 512) # 输入通道256，输出通道512

        # 全连接层，输出维度为512
        self.fc = nn.Linear(512, 512)

        # 自适应平均池化和PReLU激活函数
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.prelu = nn.PReLU()

    def forward(self, x):
        # 检查输入张量的维度
        if len(x.size()) != 4:
            raise ValueError(f"Expected input tensor to have 4 dimensions, got {len(x.size())}")

        # 前向传播过程
        out = self.conv1_x(x)
        out = self.conv2_x(out)
        out = self.conv3_x(out)
        out = self.conv4_x(out)
        out = self.conv5_x(out)
        out = self.conv6_x(out)

        # 自适应平均池化后展平
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)

        # 全连接层，输出512维特征
        out = self.fc(out)
        out = self.prelu(out)

        return out  # 输出512维特征

# Stage1的定义，包括第一个卷积层、批归一化层、ReLU激活函数和最大池化层
class Stage1(nn.Module):
    def __init__(self):
        super(Stage1, self).__init__()
        self.conv = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)  # 修改输入通道数为4
        self.bn = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.maxpool(x)
        return x

# AConvBlock的定义，包括一个卷积层、批归一化层和ReLU激活函数
class AConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(AConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

In [76]:
class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.global_avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y

class CBAM(nn.Module):
    def __init__(self, channels, reduction=16):
        super(CBAM, self).__init__()
        self.channel_attention = SEBlock(channels, reduction)
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.channel_attention(x)
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        y = torch.cat([avg_out, max_out], dim=1)
        y = self.spatial_attention(y)
        return x * y


In [77]:
import torch
import torch.nn as nn
import torchvision.models as models

class MultiModalModel(nn.Module):
    def __init__(self, num_classes=4):
        super(MultiModalModel, self).__init__()

        # **1. ResNet 处理 RGBA Mel 频谱**
        self.resnet = ResNetFeatureExtractor()  # **只接受 RGBA 4 通道输入**
        self.resnet_cbam = CBAM(512)  # **CBAM 适配 ResNet 512 维特征**

        # **2. EfficientNet-B5 处理 CQT 频谱**
        self.efficientnet = models.efficientnet_b5(pretrained=True)
        self.efficientnet.features[0][0] = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False)  # **只接受 CQT 1 通道输入**
        self.efficientnet.classifier = nn.Identity()  # **移除分类层**
        self.efficientnet_cbam = CBAM(2048)  # **CBAM 适配 EfficientNet 2048 维特征**

        # **3. 降维 & 融合层**
        self.conv1_resnet = nn.Conv2d(512, 512, kernel_size=1)  # **ResNet 降维**
        self.conv1_efficientnet = nn.Conv2d(2048, 512, kernel_size=1)  # **EfficientNet 降维**

        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))  # **全局池化**

        self.fc = nn.Sequential(
            nn.Linear(512 + 512, 512),  # **最终特征融合**
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x_resnet, x_cqt):
        """
        x_resnet: RGBA (4通道) 只输入 ResNet
        x_cqt: CQT (1通道) 只输入 EfficientNet
        """
        # **ResNet 处理 RGBA**
        resnet_feat = self.resnet(x_resnet)  # (batch, 512, 7, 7)
        resnet_feat = self.resnet_cbam(resnet_feat)
        resnet_feat = self.conv1_resnet(resnet_feat)

        # **EfficientNet 处理 CQT**
        efficientnet_feat = self.efficientnet(x_cqt)  # (batch, 2048, 7, 7)
        efficientnet_feat = self.efficientnet_cbam(efficientnet_feat)
        efficientnet_feat = self.conv1_efficientnet(efficientnet_feat)

        # **全局平均池化**
        resnet_feat = self.global_pool(resnet_feat).view(resnet_feat.size(0), -1)  # (batch, 512)
        efficientnet_feat = self.global_pool(efficientnet_feat).view(efficientnet_feat.size(0), -1)  # (batch, 512)

        # **特征融合**
        fusion = torch.cat((resnet_feat, efficientnet_feat), dim=1)  # (batch, 1024)

        # **最终分类**
        output = self.fc(fusion)
        return output


In [78]:
# 设置计算设备（GPU/CPU）
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = MultiModalModel().to(device)  # 모델을 지정된 장치로 이동

In [79]:
class CenterLoss(nn.Module):
    """
    Center loss.

    Reference:
    Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.

    Args:
        num_classes (int): number of classes.
        feat_dim (int): feature dimension.
    """
    def __init__(self, num_classes=10, feat_dim=256, use_gpu=True):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.use_gpu = use_gpu

        # 클래스 중심 초기화
        if self.use_gpu:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
        else:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))

    def forward(self, x, labels):
        """
        Args:
            x: feature matrix with shape (batch_size, feat_dim).
            labels: ground truth labels with shape (batch_size).
        """
        batch_size = x.size(0)

        # 각 클래스 중심과의 거리 계산
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        distmat.addmm_(1, -2, x, self.centers.t())

        # 레이블과 일치하는 위치 마스크 생성
        classes = torch.arange(self.num_classes).long()
        if self.use_gpu:
            classes = classes.cuda()
        labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, self.num_classes))

        # 거리 계산 후 손실 값 구하기
        dist = distmat * mask.float()
        loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size

        return loss


In [80]:
from torch.optim import Adam, SGD
import torch.nn as nn

# **修改 CenterLoss 维度为 512**
criterion = nn.CrossEntropyLoss().to(device)  # 交叉熵损失 (分类任务)
center = CenterLoss(4, 1024).to(device)  # **修改 feature dimension 24 -> 512**

# **优化器**
opti1 = Adam(model.parameters(), lr=1e-4, weight_decay=5e-4)  # **模型参数**
opti2 = SGD(center.parameters(), lr=0.2)  # **降低 CenterLoss 学习率**

# **学习率调度器**
scheduler1 = torch.optim.lr_scheduler.StepLR(opti1, step_size=10, gamma=0.5)  # **每 10 轮 lr * 0.5**
scheduler2 = torch.optim.lr_scheduler.StepLR(opti2, step_size=10, gamma=0.5)

In [81]:
from tqdm import tqdm

def train(model, dataloader, criterion, data_len, opti1, opti2):
    correct = 0
    losses = 0

    model.train()  # 设置模型为训练模式
    for rgba_data, cqt_data, target in tqdm(dataloader):  # RGBA & CQT 数据
        # 将数据移动到 GPU 或 CPU
        rgba_data = rgba_data.to(device)
        cqt_data = cqt_data.to(device)
        target = target.to(device)

        # **模型前向传播**
        cen, output = model(rgba_data, cqt_data)  # **使用两个输入**

        # **计算损失**
        loss1 = criterion(output, target)  # 交叉熵损失
        loss2 = center(cen, target)  # CenterLoss
        loss = loss1 + loss2  # **融合损失**

        # **优化器梯度清零**
        opti1.zero_grad()
        opti2.zero_grad()

        # **反向传播**
        loss.backward()
        opti1.step()
        opti2.step()

        # **计算准确率**
        pred = output.max(1, keepdim=True)[1]  # 取最大值索引
        correct += pred.eq(target.view_as(pred)).sum().item()  # 计算正确预测的数量
        losses += loss.item()


    scheduler1.step()
    scheduler2.step()

    # **返回准确率和平均损失**
    return 100 * correct / data_len, losses / data_len


In [82]:
def evaluate(model, dataloader, criterion, data_len):
    correct = 0
    total_loss = 0

    model.eval()  # **设置模型为评估模式**
    with torch.no_grad():  # **评估时不计算梯度**
        for rgba_data, cqt_data, target in dataloader:  # **获取 RGBA & CQT 数据**
            rgba_data = rgba_data.to(device)
            cqt_data = cqt_data.to(device)
            target = target.to(device)

            # **前向传播**
            _, output = model(rgba_data, cqt_data)

            # **计算损失**
            loss = criterion(output, target)
            total_loss += loss.item()

            # **计算准确率**
            pred = output.max(1, keepdim=True)[1]  # 取最大值索引
            correct += pred.eq(target.view_as(pred)).sum().item()

    # **计算最终评估结果**
    acc = 100. * correct / data_len

    return acc  # **返回准确率 & 平均损失**


In [83]:
epoch = 100

train_accuracies = []
val_accuracies = []

for i in range(epoch):
    # Training the model
    train_acc, train_loss = train(model, train_loader, criterion, len(train_loader.dataset), opti1, opti2)

    # Evaluating the model on validation data
    val_acc = evaluate(model, test_loader, criterion, len(test_loader.dataset))

    # Uncomment the line below if you want to evaluate on test data
    # test_acc = evaluate(model, test_dataloader, criterion, len(test_dataloader.dataset))

    # Storing the accuracies
    train_accuracies.append(train_acc)
    val_accuracies.append(val_acc)

    # Printing the results for the current epoch
    print(f"[Epoch: {i+1}], [Validation Acc: {val_acc:.4f}]")
    print(f"train_acc: {train_acc}, train_loss: {train_loss}")

  0%|          | 0/89 [00:00<?, ?it/s]


ValueError: not enough values to unpack (expected 4, got 2)

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import precision_score, recall_score, f1_score
import torch
import numpy as np

# Plotting training and validation accuracies
plt.figure(figsize=(10, 5))
plt.plot(train_accuracies, label='Training Accuracy')
plt.plot(val_accuracies, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.title('Training and Validation Accuracies')
plt.show()

# Setting up the device and loading the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNet().to(device)
model.load_state_dict(torch.load('./deepship7.pt'))

# Function to evaluate precision, recall, and F1-score
def evaluate_metrics(model, dataloader):
    model.eval()  # Set the model to evaluation mode
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for data, target in dataloader:
            data = data.to(device)
            target = target.to(device)

            _, output = model(data)
            pred = output.argmax(dim=1)  # Get the class with the highest score

            all_preds.extend(pred.cpu().numpy())
            all_targets.extend(target.cpu().numpy())

    # Calculate precision, recall, and F1-score for each class
    precision = precision_score(all_targets, all_preds, average=None)
    recall = recall_score(all_targets, all_preds, average=None)
    f1 = f1_score(all_targets, all_preds, average=None)

    # Calculate average precision, recall, and F1-score
    avg_precision = precision.mean()
    avg_recall = recall.mean()
    avg_f1 = f1.mean()

    return precision, recall, f1, avg_precision, avg_recall, avg_f1

# Evaluate metrics on the test dataset
class_precision, class_recall, class_f1, avg_precision, avg_recall, avg_f1 = evaluate_metrics(model, test_dataloader)

# Print performance for each class
for i in range(len(class_precision)):
    print(f"Class {i} - Precision: {class_precision[i]:.4f}, Recall: {class_recall[i]:.4f}, F1-score: {class_f1[i]:.4f}")

# Print average performance across all classes
print(f"Avg Precision: {avg_precision:.4f}, Avg Recall: {avg_recall:.4f}, Avg F1-score: {avg_f1:.4f}")
