# 导入module

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from tqdm import tqdm

device = torch.device('cuda')

# 1. 自监督学习

## 1.1 对比学习

In [12]:
class SimCLRTransform:
    def __init__(self):
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __call__(self, x):
        return self.transform(x), self.transform(x)


batch_size = 256  
dataset = datasets.STL10(root='autodl-tmp/data/data', split='unlabeled', transform=SimCLRTransform(), download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

# 定义ResNet-18模型
class ResNetSimCLR(nn.Module):
    def __init__(self, base_model, out_dim):
        super(ResNetSimCLR, self).__init__()
        self.encoder = nn.Sequential(*list(base_model.children())[:-1])
        self.fc = nn.Sequential(
            nn.Linear(base_model.fc.in_features, 2048),
            nn.ReLU(),
            nn.Linear(2048, out_dim)
        )

    def forward(self, x):
        h = self.encoder(x)
        h = h.view(h.size(0), -1)
        z = self.fc(h)
        return h, z


base_model = resnet18(weights=None)  
model = ResNetSimCLR(base_model, out_dim=128).to(device)


class NTXentLoss(nn.Module):
    def __init__(self, batch_size, temperature):
        super(NTXentLoss, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.mask = self.mask_correlated_samples(batch_size)
        self.criterion = nn.CrossEntropyLoss(reduction="sum")

    def mask_correlated_samples(self, batch_size):
        N = 2 * batch_size  
        mask = torch.ones((N, N), dtype=bool)  
        mask = mask.fill_diagonal_(0)  

        for i in range(batch_size):
            mask[i, batch_size + i] = 0  
            mask[batch_size + i, i] = 0  

        return mask

    def forward(self, z_i, z_j):
        batch_size = z_i.size(0) 
        N = 2 * batch_size
        z = torch.cat((z_i, z_j), dim=0)  
        sim = torch.matmul(z, z.T) / self.temperature  

        
        assert sim.shape == (N, N), f"Sim matrix shape is {sim.shape}, expected ({N}, {N})"

        sim_i_j = torch.diag(sim, batch_size)  
        sim_j_i = torch.diag(sim, -batch_size)

        
        assert sim_i_j.shape == (batch_size,), f"sim_i_j shape is {sim_i_j.shape}, expected ({batch_size},)"
        assert sim_j_i.shape == (batch_size,), f"sim_j_i shape is {sim_j_i.shape}, expected ({batch_size},)"

        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
        mask = self.mask_correlated_samples(batch_size)  
        negative_samples = sim[mask].reshape(N, -1)

        
        assert negative_samples.shape == (N, N-2), f"negative_samples shape is {negative_samples.shape}, expected ({N}, {N-2})"

        labels = torch.zeros(N).to(positive_samples.device).long()
        logits = torch.cat((positive_samples, negative_samples), dim=1)
        loss = self.criterion(logits, labels)
        return loss / N


criterion = NTXentLoss(batch_size=batch_size, temperature=0.5).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3) 


def train_simclr(model, dataloader, criterion, optimizer, epochs):
    model.train()
    for epoch in range(epochs):
        epoch_loss = 0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
        for images, _ in progress_bar:
            images[0] = images[0].to(device)
            images[1] = images[1].to(device)
            optimizer.zero_grad()

            # 将批次中的每个图像增强两次
            z_i, _ = model(images[0])
            z_j, _ = model(images[1])

            # 计算对比损失
            loss = criterion(z_i, z_j)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())
        
        
        print(f"Epoch [{epoch+1}/{epochs}] completed.")
        print(f"Allocated memory: {torch.cuda.memory_allocated(device)} bytes")
        print(f"Reserved memory: {torch.cuda.memory_reserved(device)} bytes")
        
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss/len(dataloader):.4f}')


train_simclr(model, dataloader, criterion, optimizer, epochs=50)
torch.save(model.state_dict(), 'resnet_simclr.pth')


Files already downloaded and verified


Epoch 1/50: 100%|██████████| 391/391 [03:30<00:00,  1.86it/s, loss=7.32] 


Epoch [1/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [1/50], Loss: 17.0222


Epoch 2/50: 100%|██████████| 391/391 [04:21<00:00,  1.49it/s, loss=5.56]


Epoch [2/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [2/50], Loss: 6.1618


Epoch 3/50: 100%|██████████| 391/391 [03:45<00:00,  1.73it/s, loss=5.42]


Epoch [3/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [3/50], Loss: 6.0214


Epoch 5/50: 100%|██████████| 391/391 [03:26<00:00,  1.90it/s, loss=5.14]


Epoch [5/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [5/50], Loss: 5.5735


Epoch 6/50: 100%|██████████| 391/391 [03:27<00:00,  1.89it/s, loss=4.68]


Epoch [6/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [6/50], Loss: 5.1602


Epoch 7/50: 100%|██████████| 391/391 [04:08<00:00,  1.57it/s, loss=3.82]


Epoch [7/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [7/50], Loss: 4.6693


Epoch 8/50: 100%|██████████| 391/391 [03:44<00:00,  1.74it/s, loss=3.5] 


Epoch [8/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [8/50], Loss: 4.1166


Epoch 9/50: 100%|██████████| 391/391 [03:29<00:00,  1.87it/s, loss=3.15]


Epoch [9/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [9/50], Loss: 3.4969


Epoch 10/50: 100%|██████████| 391/391 [03:33<00:00,  1.83it/s, loss=2.7] 


Epoch [10/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [10/50], Loss: 3.1380


Epoch 11/50: 100%|██████████| 391/391 [03:30<00:00,  1.86it/s, loss=1.86]


Epoch [11/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [11/50], Loss: 2.5202


Epoch 12/50: 100%|██████████| 391/391 [03:34<00:00,  1.83it/s, loss=1.61]


Epoch [12/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [12/50], Loss: 2.1678


Epoch 13/50: 100%|██████████| 391/391 [03:21<00:00,  1.94it/s, loss=1.32]


Epoch [13/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [13/50], Loss: 2.0278


Epoch 14/50: 100%|██████████| 391/391 [03:26<00:00,  1.90it/s, loss=1.59]


Epoch [14/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [14/50], Loss: 1.7244


Epoch 15/50: 100%|██████████| 391/391 [03:31<00:00,  1.85it/s, loss=1.43]


Epoch [15/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [15/50], Loss: 1.5610


Epoch 16/50: 100%|██████████| 391/391 [03:39<00:00,  1.78it/s, loss=1.34]


Epoch [16/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [16/50], Loss: 1.4670


Epoch 17/50: 100%|██████████| 391/391 [03:34<00:00,  1.82it/s, loss=1.37] 


Epoch [17/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [17/50], Loss: 1.7129


Epoch 18/50: 100%|██████████| 391/391 [03:34<00:00,  1.82it/s, loss=1.31]


Epoch [18/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [18/50], Loss: 1.4193


Epoch 19/50: 100%|██████████| 391/391 [03:36<00:00,  1.81it/s, loss=0.992]


Epoch [19/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [19/50], Loss: 1.2724


Epoch 20/50: 100%|██████████| 391/391 [03:27<00:00,  1.89it/s, loss=0.889]


Epoch [20/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [20/50], Loss: 1.1617


Epoch 21/50: 100%|██████████| 391/391 [03:30<00:00,  1.86it/s, loss=0.666]


Epoch [21/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [21/50], Loss: 1.0969


Epoch 22/50: 100%|██████████| 391/391 [03:30<00:00,  1.86it/s, loss=1.02] 


Epoch [22/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [22/50], Loss: 1.0529


Epoch 23/50: 100%|██████████| 391/391 [03:32<00:00,  1.84it/s, loss=0.68] 


Epoch [23/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [23/50], Loss: 1.0211


Epoch 24/50: 100%|██████████| 391/391 [03:37<00:00,  1.80it/s, loss=0.823]


Epoch [24/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [24/50], Loss: 0.9640


Epoch 25/50: 100%|██████████| 391/391 [03:22<00:00,  1.93it/s, loss=0.542]


Epoch [25/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [25/50], Loss: 0.9295


Epoch 26/50: 100%|██████████| 391/391 [03:23<00:00,  1.93it/s, loss=0.821]


Epoch [26/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [26/50], Loss: 0.9501


Epoch 28/50: 100%|██████████| 391/391 [03:32<00:00,  1.84it/s, loss=0.748]


Epoch [28/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [28/50], Loss: 0.8297


Epoch 29/50: 100%|██████████| 391/391 [03:38<00:00,  1.79it/s, loss=0.447]


Epoch [29/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [29/50], Loss: 0.8120


Epoch 30/50: 100%|██████████| 391/391 [03:27<00:00,  1.88it/s, loss=0.926]


Epoch [30/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [30/50], Loss: 0.7863


Epoch 31/50: 100%|██████████| 391/391 [03:42<00:00,  1.75it/s, loss=0.486]


Epoch [31/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [31/50], Loss: 0.7465


Epoch 32/50:   5%|▌         | 20/391 [00:15<03:09,  1.95it/s, loss=0.73] IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

Epoch 33/50: 100%|██████████| 391/391 [03:48<00:00,  1.71it/s, loss=0.607]


Epoch [33/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [33/50], Loss: 0.7831


Epoch 34/50: 100%|██████████| 391/391 [04:06<00:00,  1.59it/s, loss=0.576]


Epoch [34/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [34/50], Loss: 0.7203


Epoch 35/50: 100%|██████████| 391/391 [04:01<00:00,  1.62it/s, loss=0.565]


Epoch [35/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [35/50], Loss: 0.6868


Epoch 36/50: 100%|██████████| 391/391 [03:55<00:00,  1.66it/s, loss=0.662]


Epoch [36/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [36/50], Loss: 0.6573


Epoch 37/50: 100%|██████████| 391/391 [04:00<00:00,  1.62it/s, loss=0.479]


Epoch [37/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [37/50], Loss: 0.6505


Epoch 38/50: 100%|██████████| 391/391 [04:02<00:00,  1.61it/s, loss=0.656]


Epoch [38/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [38/50], Loss: 0.6419


Epoch 39/50: 100%|██████████| 391/391 [03:56<00:00,  1.65it/s, loss=0.506]


Epoch [39/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [39/50], Loss: 0.6185


Epoch 40/50: 100%|██████████| 391/391 [03:59<00:00,  1.63it/s, loss=0.485]


Epoch [40/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [40/50], Loss: 0.6097


Epoch 41/50: 100%|██████████| 391/391 [03:52<00:00,  1.68it/s, loss=0.582]


Epoch [41/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [41/50], Loss: 0.5957


Epoch 42/50: 100%|██████████| 391/391 [03:54<00:00,  1.66it/s, loss=0.502]


Epoch [42/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [42/50], Loss: 0.5863


Epoch 43/50: 100%|██████████| 391/391 [04:04<00:00,  1.60it/s, loss=0.617]


Epoch [43/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [43/50], Loss: 0.5672


Epoch 44/50: 100%|██████████| 391/391 [04:02<00:00,  1.61it/s, loss=0.403]


Epoch [44/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [44/50], Loss: 0.5673


Epoch 45/50: 100%|██████████| 391/391 [03:54<00:00,  1.67it/s, loss=0.449]


Epoch [45/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [45/50], Loss: 0.5538


Epoch 46/50: 100%|██████████| 391/391 [03:53<00:00,  1.68it/s, loss=0.627]


Epoch [46/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [46/50], Loss: 0.5392


Epoch 47/50: 100%|██████████| 391/391 [04:10<00:00,  1.56it/s, loss=0.499]


Epoch [47/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [47/50], Loss: 0.5272


Epoch 48/50: 100%|██████████| 391/391 [04:01<00:00,  1.62it/s, loss=0.512]


Epoch [48/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [48/50], Loss: 0.5230


Epoch 49/50: 100%|██████████| 391/391 [04:03<00:00,  1.60it/s, loss=0.397]


Epoch [49/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [49/50], Loss: 0.5225


Epoch 50/50: 100%|██████████| 391/391 [03:56<00:00,  1.65it/s, loss=0.403]


Epoch [50/50] completed.
Allocated memory: 1982695936 bytes
Reserved memory: 18452840448 bytes
Epoch [50/50], Loss: 0.5097


## 1.2 监督学习训练输出层

In [14]:
# CIFAR-100 数据集加载
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = datasets.CIFAR100(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.CIFAR100(root='./data', train=False, transform=transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=4)


class LinearClassifier(nn.Module):
    def __init__(self, encoder, num_classes):
        super(LinearClassifier, self).__init__()
        self.encoder = encoder
        self.fc = nn.Linear(512, num_classes)  # 512是ResNet-18编码器输出的特征维度

    def forward(self, x):
        with torch.no_grad():
            features = self.encoder(x)
            features = features.view(features.size(0), -1)
        out = self.fc(features)
        return out

# 加载预训练模型权重
model = ResNetSimCLR(resnet18(weights=None), out_dim=128).to(device)
model.load_state_dict(torch.load('resnet_simclr.pth'))

# 冻结编码器部分，只训练线性分类器
for param in model.encoder.parameters():
    param.requires_grad = False

linear_classifier = LinearClassifier(model.encoder, num_classes=100).to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(linear_classifier.fc.parameters(), lr=1e-3)

# 训练线性分类器
def train_linear_classifier(model, dataloader, criterion, optimizer, epochs):
    model.train()
    for epoch in range(epochs):
        epoch_loss = 0
        correct = 0
        total = 0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
        for images, labels in progress_bar:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            progress_bar.set_postfix(loss=loss.item(), accuracy=100.*correct/total)
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss/len(dataloader):.4f}, Accuracy: {100.*correct/total:.2f}%')

train_linear_classifier(linear_classifier, train_loader, criterion, optimizer, epochs=50)


Files already downloaded and verified
Files already downloaded and verified


Epoch 1/50: 100%|██████████| 196/196 [00:24<00:00,  8.11it/s, accuracy=21.8, loss=3.52]


Epoch [1/50], Loss: 4.0219, Accuracy: 21.77%


Epoch 2/50: 100%|██████████| 196/196 [00:23<00:00,  8.35it/s, accuracy=35, loss=3.02]  


Epoch [2/50], Loss: 3.2560, Accuracy: 35.00%


Epoch 3/50: 100%|██████████| 196/196 [00:23<00:00,  8.33it/s, accuracy=38.5, loss=2.67]


Epoch [3/50], Loss: 2.8684, Accuracy: 38.49%


Epoch 4/50: 100%|██████████| 196/196 [00:23<00:00,  8.37it/s, accuracy=40.9, loss=2.54]


Epoch [4/50], Loss: 2.6456, Accuracy: 40.93%


Epoch 5/50: 100%|██████████| 196/196 [00:23<00:00,  8.42it/s, accuracy=42.4, loss=2.46]


Epoch [5/50], Loss: 2.5001, Accuracy: 42.41%


Epoch 6/50: 100%|██████████| 196/196 [00:23<00:00,  8.39it/s, accuracy=43.7, loss=2.34]


Epoch [6/50], Loss: 2.3970, Accuracy: 43.72%


Epoch 7/50: 100%|██████████| 196/196 [00:23<00:00,  8.23it/s, accuracy=44.6, loss=2.34]


Epoch [7/50], Loss: 2.3169, Accuracy: 44.63%


Epoch 8/50: 100%|██████████| 196/196 [00:23<00:00,  8.37it/s, accuracy=45.5, loss=2.33]


Epoch [8/50], Loss: 2.2546, Accuracy: 45.48%


Epoch 9/50: 100%|██████████| 196/196 [00:23<00:00,  8.33it/s, accuracy=46.3, loss=2.22]


Epoch [9/50], Loss: 2.2024, Accuracy: 46.27%


Epoch 10/50: 100%|██████████| 196/196 [00:23<00:00,  8.36it/s, accuracy=47.1, loss=2.17]


Epoch [10/50], Loss: 2.1596, Accuracy: 47.10%


Epoch 11/50: 100%|██████████| 196/196 [00:24<00:00,  7.98it/s, accuracy=47.7, loss=2.21]


Epoch [11/50], Loss: 2.1214, Accuracy: 47.73%


Epoch 12/50: 100%|██████████| 196/196 [00:24<00:00,  7.93it/s, accuracy=48.2, loss=1.86]


Epoch [12/50], Loss: 2.0880, Accuracy: 48.20%


Epoch 13/50: 100%|██████████| 196/196 [00:23<00:00,  8.36it/s, accuracy=48.8, loss=2.03]


Epoch [13/50], Loss: 2.0582, Accuracy: 48.84%


Epoch 14/50: 100%|██████████| 196/196 [00:23<00:00,  8.37it/s, accuracy=49.2, loss=2.21]


Epoch [14/50], Loss: 2.0328, Accuracy: 49.23%


Epoch 15/50: 100%|██████████| 196/196 [00:23<00:00,  8.33it/s, accuracy=49.8, loss=2.13]


Epoch [15/50], Loss: 2.0093, Accuracy: 49.76%


Epoch 16/50: 100%|██████████| 196/196 [00:23<00:00,  8.38it/s, accuracy=50.1, loss=2.04]


Epoch [16/50], Loss: 1.9873, Accuracy: 50.10%


Epoch 17/50: 100%|██████████| 196/196 [00:21<00:00,  9.03it/s, accuracy=50.4, loss=1.9] 


Epoch [17/50], Loss: 1.9661, Accuracy: 50.43%


Epoch 18/50: 100%|██████████| 196/196 [00:23<00:00,  8.33it/s, accuracy=50.7, loss=2]   


Epoch [18/50], Loss: 1.9495, Accuracy: 50.67%


Epoch 19/50: 100%|██████████| 196/196 [00:23<00:00,  8.43it/s, accuracy=51.1, loss=2.16]


Epoch [19/50], Loss: 1.9318, Accuracy: 51.13%


Epoch 20/50: 100%|██████████| 196/196 [00:23<00:00,  8.45it/s, accuracy=51.5, loss=2.05]


Epoch [20/50], Loss: 1.9148, Accuracy: 51.47%


Epoch 21/50: 100%|██████████| 196/196 [00:23<00:00,  8.36it/s, accuracy=51.9, loss=2.03]


Epoch [21/50], Loss: 1.9003, Accuracy: 51.88%


Epoch 22/50: 100%|██████████| 196/196 [00:23<00:00,  8.32it/s, accuracy=52.1, loss=1.78]


Epoch [22/50], Loss: 1.8857, Accuracy: 52.07%


Epoch 23/50: 100%|██████████| 196/196 [00:23<00:00,  8.20it/s, accuracy=52.4, loss=1.61]


Epoch [23/50], Loss: 1.8709, Accuracy: 52.40%


Epoch 24/50: 100%|██████████| 196/196 [00:25<00:00,  7.69it/s, accuracy=52.6, loss=1.92]


Epoch [24/50], Loss: 1.8600, Accuracy: 52.57%


Epoch 25/50: 100%|██████████| 196/196 [00:24<00:00,  8.04it/s, accuracy=52.7, loss=1.94]


Epoch [25/50], Loss: 1.8471, Accuracy: 52.74%


Epoch 26/50: 100%|██████████| 196/196 [00:23<00:00,  8.26it/s, accuracy=53, loss=1.81]  


Epoch [26/50], Loss: 1.8362, Accuracy: 53.03%


Epoch 27/50: 100%|██████████| 196/196 [00:23<00:00,  8.38it/s, accuracy=53.3, loss=1.92]


Epoch [27/50], Loss: 1.8246, Accuracy: 53.30%


Epoch 28/50: 100%|██████████| 196/196 [00:23<00:00,  8.25it/s, accuracy=53.5, loss=1.69]


Epoch [28/50], Loss: 1.8144, Accuracy: 53.49%


Epoch 29/50: 100%|██████████| 196/196 [00:23<00:00,  8.33it/s, accuracy=53.9, loss=1.86]


Epoch [29/50], Loss: 1.8042, Accuracy: 53.92%


Epoch 30/50: 100%|██████████| 196/196 [00:23<00:00,  8.37it/s, accuracy=53.9, loss=1.63]


Epoch [30/50], Loss: 1.7963, Accuracy: 53.94%


Epoch 31/50: 100%|██████████| 196/196 [00:23<00:00,  8.37it/s, accuracy=54.2, loss=1.79]


Epoch [31/50], Loss: 1.7857, Accuracy: 54.21%


Epoch 32/50: 100%|██████████| 196/196 [00:23<00:00,  8.41it/s, accuracy=54.4, loss=1.59]


Epoch [32/50], Loss: 1.7765, Accuracy: 54.38%


Epoch 33/50: 100%|██████████| 196/196 [00:23<00:00,  8.38it/s, accuracy=54.5, loss=1.85]


Epoch [33/50], Loss: 1.7685, Accuracy: 54.48%


Epoch 34/50: 100%|██████████| 196/196 [00:23<00:00,  8.40it/s, accuracy=54.7, loss=1.73]


Epoch [34/50], Loss: 1.7607, Accuracy: 54.69%


Epoch 35/50: 100%|██████████| 196/196 [00:23<00:00,  8.37it/s, accuracy=54.9, loss=1.8] 


Epoch [35/50], Loss: 1.7537, Accuracy: 54.90%


Epoch 36/50: 100%|██████████| 196/196 [00:23<00:00,  8.35it/s, accuracy=55.1, loss=1.82]


Epoch [36/50], Loss: 1.7465, Accuracy: 55.08%


Epoch 37/50: 100%|██████████| 196/196 [00:23<00:00,  8.37it/s, accuracy=55.3, loss=1.99]


Epoch [37/50], Loss: 1.7370, Accuracy: 55.27%


Epoch 38/50: 100%|██████████| 196/196 [00:23<00:00,  8.39it/s, accuracy=55.4, loss=1.56]


Epoch [38/50], Loss: 1.7300, Accuracy: 55.43%


Epoch 39/50: 100%|██████████| 196/196 [00:23<00:00,  8.46it/s, accuracy=55.4, loss=1.43]


Epoch [39/50], Loss: 1.7224, Accuracy: 55.43%


Epoch 40/50: 100%|██████████| 196/196 [00:24<00:00,  8.08it/s, accuracy=55.7, loss=1.62]


Epoch [40/50], Loss: 1.7151, Accuracy: 55.75%


Epoch 41/50: 100%|██████████| 196/196 [00:23<00:00,  8.39it/s, accuracy=55.9, loss=1.85]


Epoch [41/50], Loss: 1.7104, Accuracy: 55.86%


Epoch 42/50: 100%|██████████| 196/196 [00:23<00:00,  8.27it/s, accuracy=56, loss=1.52]  


Epoch [42/50], Loss: 1.7036, Accuracy: 55.98%


Epoch 43/50: 100%|██████████| 196/196 [00:22<00:00,  8.84it/s, accuracy=56, loss=1.61]  


Epoch [43/50], Loss: 1.6977, Accuracy: 55.97%


Epoch 44/50: 100%|██████████| 196/196 [00:22<00:00,  8.56it/s, accuracy=56.2, loss=1.71]


Epoch [44/50], Loss: 1.6924, Accuracy: 56.25%


Epoch 45/50: 100%|██████████| 196/196 [00:23<00:00,  8.47it/s, accuracy=56.2, loss=1.58]


Epoch [45/50], Loss: 1.6854, Accuracy: 56.24%


Epoch 46/50: 100%|██████████| 196/196 [00:23<00:00,  8.45it/s, accuracy=56.5, loss=1.78]


Epoch [46/50], Loss: 1.6805, Accuracy: 56.53%


Epoch 47/50: 100%|██████████| 196/196 [00:23<00:00,  8.39it/s, accuracy=56.5, loss=1.43]


Epoch [47/50], Loss: 1.6757, Accuracy: 56.47%


Epoch 48/50: 100%|██████████| 196/196 [00:23<00:00,  8.43it/s, accuracy=56.6, loss=1.68]


Epoch [48/50], Loss: 1.6710, Accuracy: 56.58%


Epoch 49/50: 100%|██████████| 196/196 [00:23<00:00,  8.23it/s, accuracy=56.9, loss=1.87]


Epoch [49/50], Loss: 1.6651, Accuracy: 56.92%


Epoch 50/50: 100%|██████████| 196/196 [00:23<00:00,  8.47it/s, accuracy=56.9, loss=1.65]

Epoch [50/50], Loss: 1.6601, Accuracy: 56.90%





# 2. 监督学习

## 2.1 加载预训练模型微调

In [16]:
# 加载在 ImageNet 上预训练的 ResNet-18 模型
pretrained_model = resnet18(pretrained=True).to(device)

# 冻结编码器部分，只训练线性分类器
for param in pretrained_model.parameters():
    param.requires_grad = False

# 替换全连接层，调整输出维度为 100
num_features = pretrained_model.fc.in_features
pretrained_model.fc = nn.Linear(num_features, 100).to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(pretrained_model.fc.parameters(), lr=1e-3)



Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [02:58<00:00, 263kB/s]
Epoch 1/50: 100%|██████████| 196/196 [00:23<00:00,  8.23it/s, accuracy=37.5, loss=2.14]


Epoch [1/50], Loss: 2.8293, Accuracy: 37.46%


Epoch 2/50: 100%|██████████| 196/196 [00:24<00:00,  8.02it/s, accuracy=54.8, loss=1.98]


Epoch [2/50], Loss: 1.7909, Accuracy: 54.78%


Epoch 3/50: 100%|██████████| 196/196 [00:23<00:00,  8.45it/s, accuracy=58.6, loss=1.59]


Epoch [3/50], Loss: 1.5680, Accuracy: 58.61%


Epoch 4/50: 100%|██████████| 196/196 [00:23<00:00,  8.32it/s, accuracy=60.6, loss=1.51]


Epoch [4/50], Loss: 1.4619, Accuracy: 60.57%


Epoch 5/50: 100%|██████████| 196/196 [00:23<00:00,  8.28it/s, accuracy=62, loss=1.26]  


Epoch [5/50], Loss: 1.3945, Accuracy: 61.95%


Epoch 6/50: 100%|██████████| 196/196 [00:23<00:00,  8.29it/s, accuracy=63.3, loss=1.21]


Epoch [6/50], Loss: 1.3433, Accuracy: 63.28%


Epoch 7/50: 100%|██████████| 196/196 [00:23<00:00,  8.51it/s, accuracy=64.1, loss=1.43]


Epoch [7/50], Loss: 1.3034, Accuracy: 64.11%


Epoch 8/50: 100%|██████████| 196/196 [00:22<00:00,  8.63it/s, accuracy=64.5, loss=1.36]


Epoch [8/50], Loss: 1.2753, Accuracy: 64.54%


Epoch 9/50: 100%|██████████| 196/196 [00:22<00:00,  8.77it/s, accuracy=65.3, loss=1.36]


Epoch [9/50], Loss: 1.2453, Accuracy: 65.26%


Epoch 10/50: 100%|██████████| 196/196 [00:23<00:00,  8.27it/s, accuracy=66, loss=1.53]  


Epoch [10/50], Loss: 1.2234, Accuracy: 66.04%


Epoch 11/50: 100%|██████████| 196/196 [00:23<00:00,  8.28it/s, accuracy=66.2, loss=1.15]


Epoch [11/50], Loss: 1.2041, Accuracy: 66.18%


Epoch 12/50: 100%|██████████| 196/196 [00:23<00:00,  8.37it/s, accuracy=66.6, loss=0.992]


Epoch [12/50], Loss: 1.1903, Accuracy: 66.60%


Epoch 13/50: 100%|██████████| 196/196 [00:21<00:00,  9.33it/s, accuracy=67.1, loss=1.02] 


Epoch [13/50], Loss: 1.1713, Accuracy: 67.13%


Epoch 14/50: 100%|██████████| 196/196 [00:22<00:00,  8.67it/s, accuracy=67.5, loss=1.19] 


Epoch [14/50], Loss: 1.1553, Accuracy: 67.54%


Epoch 15/50: 100%|██████████| 196/196 [00:22<00:00,  8.54it/s, accuracy=68, loss=1.07]   


Epoch [15/50], Loss: 1.1420, Accuracy: 67.96%


Epoch 16/50: 100%|██████████| 196/196 [00:23<00:00,  8.42it/s, accuracy=68.2, loss=0.989]


Epoch [16/50], Loss: 1.1318, Accuracy: 68.19%


Epoch 17/50: 100%|██████████| 196/196 [00:22<00:00,  8.65it/s, accuracy=68.2, loss=1.4]  


Epoch [17/50], Loss: 1.1245, Accuracy: 68.18%


Epoch 18/50: 100%|██████████| 196/196 [00:22<00:00,  8.70it/s, accuracy=68.5, loss=1.2]  


Epoch [18/50], Loss: 1.1098, Accuracy: 68.51%


Epoch 19/50: 100%|██████████| 196/196 [00:23<00:00,  8.33it/s, accuracy=68.6, loss=1.32] 


Epoch [19/50], Loss: 1.1037, Accuracy: 68.62%


Epoch 20/50: 100%|██████████| 196/196 [00:23<00:00,  8.43it/s, accuracy=69, loss=1.14]   


Epoch [20/50], Loss: 1.0954, Accuracy: 69.01%


Epoch 21/50: 100%|██████████| 196/196 [00:23<00:00,  8.46it/s, accuracy=69.1, loss=0.946]


Epoch [21/50], Loss: 1.0860, Accuracy: 69.08%


Epoch 22/50: 100%|██████████| 196/196 [00:23<00:00,  8.49it/s, accuracy=69.5, loss=0.947]


Epoch [22/50], Loss: 1.0754, Accuracy: 69.48%


Epoch 23/50: 100%|██████████| 196/196 [00:22<00:00,  8.69it/s, accuracy=69.7, loss=1.2]  


Epoch [23/50], Loss: 1.0689, Accuracy: 69.75%


Epoch 24/50: 100%|██████████| 196/196 [00:23<00:00,  8.49it/s, accuracy=69.9, loss=1.05] 


Epoch [24/50], Loss: 1.0642, Accuracy: 69.88%


Epoch 25/50: 100%|██████████| 196/196 [00:23<00:00,  8.43it/s, accuracy=70.1, loss=1.19] 


Epoch [25/50], Loss: 1.0559, Accuracy: 70.06%


Epoch 26/50: 100%|██████████| 196/196 [00:23<00:00,  8.38it/s, accuracy=70.3, loss=1.3]  


Epoch [26/50], Loss: 1.0486, Accuracy: 70.29%


Epoch 27/50: 100%|██████████| 196/196 [00:22<00:00,  8.66it/s, accuracy=70.3, loss=0.967]


Epoch [27/50], Loss: 1.0421, Accuracy: 70.34%


Epoch 28/50: 100%|██████████| 196/196 [00:22<00:00,  8.69it/s, accuracy=70.5, loss=0.763]


Epoch [28/50], Loss: 1.0362, Accuracy: 70.54%


Epoch 29/50: 100%|██████████| 196/196 [00:21<00:00,  9.04it/s, accuracy=70.4, loss=1.26] 


Epoch [29/50], Loss: 1.0332, Accuracy: 70.43%


Epoch 30/50: 100%|██████████| 196/196 [00:23<00:00,  8.48it/s, accuracy=70.8, loss=1.42] 


Epoch [30/50], Loss: 1.0244, Accuracy: 70.80%


Epoch 31/50: 100%|██████████| 196/196 [00:23<00:00,  8.28it/s, accuracy=70.8, loss=1.11] 


Epoch [31/50], Loss: 1.0207, Accuracy: 70.78%


Epoch 32/50: 100%|██████████| 196/196 [00:24<00:00,  7.96it/s, accuracy=70.9, loss=1.27] 


Epoch [32/50], Loss: 1.0164, Accuracy: 70.91%


Epoch 33/50: 100%|██████████| 196/196 [00:21<00:00,  9.04it/s, accuracy=71.2, loss=1.26] 


Epoch [33/50], Loss: 1.0120, Accuracy: 71.15%


Epoch 34/50: 100%|██████████| 196/196 [00:23<00:00,  8.39it/s, accuracy=71.3, loss=1.02] 


Epoch [34/50], Loss: 1.0048, Accuracy: 71.31%


Epoch 35/50: 100%|██████████| 196/196 [00:23<00:00,  8.38it/s, accuracy=71.2, loss=0.971]


Epoch [35/50], Loss: 1.0024, Accuracy: 71.25%


Epoch 36/50: 100%|██████████| 196/196 [00:23<00:00,  8.35it/s, accuracy=71.2, loss=0.889]


Epoch [36/50], Loss: 0.9969, Accuracy: 71.24%


Epoch 37/50: 100%|██████████| 196/196 [00:24<00:00,  8.09it/s, accuracy=71.5, loss=1.23] 


Epoch [37/50], Loss: 0.9942, Accuracy: 71.54%


Epoch 38/50: 100%|██████████| 196/196 [00:23<00:00,  8.49it/s, accuracy=71.5, loss=1.12] 


Epoch [38/50], Loss: 0.9903, Accuracy: 71.53%


Epoch 39/50: 100%|██████████| 196/196 [00:23<00:00,  8.25it/s, accuracy=71.7, loss=1.03] 


Epoch [39/50], Loss: 0.9864, Accuracy: 71.66%


Epoch 40/50: 100%|██████████| 196/196 [00:22<00:00,  8.53it/s, accuracy=71.8, loss=1.18] 


Epoch [40/50], Loss: 0.9839, Accuracy: 71.79%


Epoch 41/50: 100%|██████████| 196/196 [00:21<00:00,  9.05it/s, accuracy=72, loss=0.924]  


Epoch [41/50], Loss: 0.9780, Accuracy: 72.02%


Epoch 42/50: 100%|██████████| 196/196 [00:21<00:00,  9.00it/s, accuracy=72, loss=1.1]    


Epoch [42/50], Loss: 0.9769, Accuracy: 72.01%


Epoch 43/50: 100%|██████████| 196/196 [00:21<00:00,  9.16it/s, accuracy=72, loss=1.29]   


Epoch [43/50], Loss: 0.9719, Accuracy: 72.03%


Epoch 44/50: 100%|██████████| 196/196 [00:22<00:00,  8.86it/s, accuracy=72.1, loss=1.08] 


Epoch [44/50], Loss: 0.9697, Accuracy: 72.05%


Epoch 45/50: 100%|██████████| 196/196 [00:22<00:00,  8.67it/s, accuracy=72.2, loss=1.32] 


Epoch [45/50], Loss: 0.9669, Accuracy: 72.23%


Epoch 46/50: 100%|██████████| 196/196 [00:22<00:00,  8.68it/s, accuracy=72.4, loss=1.02] 


Epoch [46/50], Loss: 0.9604, Accuracy: 72.38%


Epoch 47/50: 100%|██████████| 196/196 [00:22<00:00,  8.66it/s, accuracy=72.5, loss=0.836]


Epoch [47/50], Loss: 0.9560, Accuracy: 72.49%


Epoch 48/50: 100%|██████████| 196/196 [00:22<00:00,  8.74it/s, accuracy=72.6, loss=1.04] 


Epoch [48/50], Loss: 0.9524, Accuracy: 72.59%


Epoch 49/50: 100%|██████████| 196/196 [00:21<00:00,  9.09it/s, accuracy=72.4, loss=0.836]


Epoch [49/50], Loss: 0.9527, Accuracy: 72.41%


Epoch 50/50: 100%|██████████| 196/196 [00:22<00:00,  8.59it/s, accuracy=72.5, loss=0.927]

Epoch [50/50], Loss: 0.9474, Accuracy: 72.46%





NameError: name 'evaluate' is not defined

## 2.2 从零开始训练

In [19]:
# 加载未经过任何预训练的 ResNet-18 模型
scratch_model = resnet18(weights=None).to(device)

# 替换全连接层，调整输出维度为 100
num_features = scratch_model.fc.in_features
scratch_model.fc = nn.Linear(num_features, 100).to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(scratch_model.parameters(), lr=1e-3)

# 从零开始训练模型
train_linear_classifier(scratch_model, train_loader, criterion, optimizer, epochs=50)


Epoch 1/50: 100%|██████████| 196/196 [00:39<00:00,  4.91it/s, accuracy=13.4, loss=3.54]


Epoch [1/50], Loss: 3.6796, Accuracy: 13.39%


Epoch 2/50: 100%|██████████| 196/196 [00:40<00:00,  4.85it/s, accuracy=26.6, loss=2.77]


Epoch [2/50], Loss: 2.9125, Accuracy: 26.60%


Epoch 3/50: 100%|██████████| 196/196 [00:39<00:00,  4.99it/s, accuracy=38.9, loss=2.17]


Epoch [3/50], Loss: 2.3094, Accuracy: 38.94%


Epoch 4/50: 100%|██████████| 196/196 [00:37<00:00,  5.22it/s, accuracy=48.1, loss=1.73]


Epoch [4/50], Loss: 1.8884, Accuracy: 48.13%


Epoch 5/50: 100%|██████████| 196/196 [00:39<00:00,  4.96it/s, accuracy=55.7, loss=1.43]


Epoch [5/50], Loss: 1.5784, Accuracy: 55.68%


Epoch 6/50: 100%|██████████| 196/196 [00:39<00:00,  4.91it/s, accuracy=61.8, loss=1.21]


Epoch [6/50], Loss: 1.3235, Accuracy: 61.78%


Epoch 7/50: 100%|██████████| 196/196 [00:39<00:00,  4.97it/s, accuracy=68, loss=0.998]  


Epoch [7/50], Loss: 1.0938, Accuracy: 67.96%


Epoch 8/50: 100%|██████████| 196/196 [00:40<00:00,  4.83it/s, accuracy=74.1, loss=0.983]


Epoch [8/50], Loss: 0.8643, Accuracy: 74.14%


Epoch 9/50: 100%|██████████| 196/196 [00:39<00:00,  5.03it/s, accuracy=80.1, loss=0.838]


Epoch [9/50], Loss: 0.6572, Accuracy: 80.12%


Epoch 10/50: 100%|██████████| 196/196 [00:38<00:00,  5.11it/s, accuracy=86.8, loss=0.777]


Epoch [10/50], Loss: 0.4419, Accuracy: 86.77%


Epoch 11/50: 100%|██████████| 196/196 [00:39<00:00,  4.97it/s, accuracy=92.5, loss=0.494]


Epoch [11/50], Loss: 0.2682, Accuracy: 92.47%


Epoch 12/50: 100%|██████████| 196/196 [00:38<00:00,  5.04it/s, accuracy=96.4, loss=0.229] 


Epoch [12/50], Loss: 0.1482, Accuracy: 96.36%


Epoch 13/50: 100%|██████████| 196/196 [00:39<00:00,  4.99it/s, accuracy=98.4, loss=0.121] 


Epoch [13/50], Loss: 0.0765, Accuracy: 98.44%


Epoch 14/50: 100%|██████████| 196/196 [00:39<00:00,  4.92it/s, accuracy=99.4, loss=0.0394]


Epoch [14/50], Loss: 0.0378, Accuracy: 99.44%


Epoch 15/50: 100%|██████████| 196/196 [00:39<00:00,  4.99it/s, accuracy=99.7, loss=0.0143] 


Epoch [15/50], Loss: 0.0219, Accuracy: 99.71%


Epoch 16/50: 100%|██████████| 196/196 [00:37<00:00,  5.16it/s, accuracy=99.9, loss=0.0115] 


Epoch [16/50], Loss: 0.0116, Accuracy: 99.90%


Epoch 17/50: 100%|██████████| 196/196 [00:40<00:00,  4.87it/s, accuracy=99.9, loss=0.00556]


Epoch [17/50], Loss: 0.0080, Accuracy: 99.93%


Epoch 18/50: 100%|██████████| 196/196 [00:39<00:00,  4.95it/s, accuracy=100, loss=0.00361]


Epoch [18/50], Loss: 0.0038, Accuracy: 99.96%


Epoch 19/50: 100%|██████████| 196/196 [00:40<00:00,  4.83it/s, accuracy=99.9, loss=0.0037] 


Epoch [19/50], Loss: 0.0046, Accuracy: 99.94%


Epoch 20/50: 100%|██████████| 196/196 [00:38<00:00,  5.11it/s, accuracy=89.5, loss=1.22]  


Epoch [20/50], Loss: 0.3582, Accuracy: 89.54%


Epoch 23/50: 100%|██████████| 196/196 [00:39<00:00,  4.96it/s, accuracy=99.8, loss=0.00612]


Epoch [23/50], Loss: 0.0146, Accuracy: 99.84%


Epoch 24/50: 100%|██████████| 196/196 [00:38<00:00,  5.11it/s, accuracy=99.9, loss=0.00466]


Epoch [24/50], Loss: 0.0058, Accuracy: 99.94%


Epoch 25/50: 100%|██████████| 196/196 [00:38<00:00,  5.07it/s, accuracy=100, loss=0.00592]


Epoch [25/50], Loss: 0.0038, Accuracy: 99.96%


Epoch 26/50: 100%|██████████| 196/196 [00:39<00:00,  4.96it/s, accuracy=100, loss=0.0022]  


Epoch [26/50], Loss: 0.0038, Accuracy: 99.95%


Epoch 27/50: 100%|██████████| 196/196 [00:39<00:00,  4.96it/s, accuracy=100, loss=0.00125] 


Epoch [27/50], Loss: 0.0026, Accuracy: 99.96%


Epoch 28/50: 100%|██████████| 196/196 [00:39<00:00,  4.90it/s, accuracy=100, loss=0.00196] 


Epoch [28/50], Loss: 0.0020, Accuracy: 99.97%


Epoch 29/50: 100%|██████████| 196/196 [00:39<00:00,  4.97it/s, accuracy=100, loss=0.000726]


Epoch [29/50], Loss: 0.0016, Accuracy: 99.98%


Epoch 30/50: 100%|██████████| 196/196 [00:37<00:00,  5.20it/s, accuracy=100, loss=0.000821]


Epoch [30/50], Loss: 0.0018, Accuracy: 99.97%


Epoch 31/50:  39%|███▉      | 76/196 [00:15<00:20,  5.73it/s, accuracy=100, loss=0.0123]  IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

Epoch 36/50: 100%|██████████| 196/196 [00:39<00:00,  4.97it/s, accuracy=100, loss=0.000624]


Epoch [36/50], Loss: 0.0010, Accuracy: 99.98%


Epoch 37/50: 100%|██████████| 196/196 [00:38<00:00,  5.07it/s, accuracy=100, loss=0.000307]


Epoch [37/50], Loss: 0.0009, Accuracy: 99.98%


Epoch 38/50: 100%|██████████| 196/196 [00:39<00:00,  4.96it/s, accuracy=100, loss=0.000367]


Epoch [38/50], Loss: 0.0010, Accuracy: 99.97%


Epoch 39/50: 100%|██████████| 196/196 [00:39<00:00,  4.98it/s, accuracy=100, loss=0.00054] 


Epoch [39/50], Loss: 0.0008, Accuracy: 99.98%


Epoch 40/50: 100%|██████████| 196/196 [00:38<00:00,  5.16it/s, accuracy=100, loss=0.000477]


Epoch [40/50], Loss: 0.0008, Accuracy: 99.98%


Epoch 41/50: 100%|██████████| 196/196 [00:38<00:00,  5.08it/s, accuracy=100, loss=0.000323]


Epoch [41/50], Loss: 0.0009, Accuracy: 99.98%


Epoch 42/50: 100%|██████████| 196/196 [00:38<00:00,  5.06it/s, accuracy=100, loss=0.000351]


Epoch [42/50], Loss: 0.0009, Accuracy: 99.97%


Epoch 43/50: 100%|██████████| 196/196 [00:37<00:00,  5.27it/s, accuracy=100, loss=0.000159]


Epoch [43/50], Loss: 0.0006, Accuracy: 99.98%


Epoch 44/50: 100%|██████████| 196/196 [00:37<00:00,  5.18it/s, accuracy=77.2, loss=0.898] 


Epoch [44/50], Loss: 0.8383, Accuracy: 77.23%


Epoch 45/50: 100%|██████████| 196/196 [00:37<00:00,  5.28it/s, accuracy=90, loss=0.244]  


Epoch [45/50], Loss: 0.3189, Accuracy: 89.96%


Epoch 46/50: 100%|██████████| 196/196 [00:38<00:00,  5.12it/s, accuracy=98.7, loss=0.0268]


Epoch [46/50], Loss: 0.0529, Accuracy: 98.71%


Epoch 47/50: 100%|██████████| 196/196 [00:38<00:00,  5.09it/s, accuracy=99.9, loss=0.0085] 


Epoch [47/50], Loss: 0.0103, Accuracy: 99.89%


Epoch 48/50: 100%|██████████| 196/196 [00:37<00:00,  5.21it/s, accuracy=100, loss=0.00273]


Epoch [48/50], Loss: 0.0038, Accuracy: 99.96%


Epoch 49/50: 100%|██████████| 196/196 [00:36<00:00,  5.34it/s, accuracy=100, loss=0.00219] 


Epoch [49/50], Loss: 0.0026, Accuracy: 99.96%


Epoch 50/50: 100%|██████████| 196/196 [00:38<00:00,  5.05it/s, accuracy=100, loss=0.00182] 

Epoch [50/50], Loss: 0.0018, Accuracy: 99.97%





# 3. 评估有监督与无监督

In [None]:
def evaluate(model, dataloader, criterion):
    model.eval()
    correct = 0
    total = 0
    test_loss = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    accuracy = 100. * correct / total
    test_loss /= len(dataloader)
    print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {accuracy:.2f}%')
    return test_loss, accuracy

test_loss, test_accuracy = evaluate(linear_classifier, test_loader, criterion)


In [None]:
test_loss, test_accuracy = evaluate(pretrained_model, test_loader, criterion)

In [None]:
test_loss, test_accuracy = evaluate(scratch_model, test_loader, criterion)