In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import argparse
from torchvision.ops import MultiScaleRoIAlign
from torchvision.datasets import ImageFolder
import os
import torch.nn.functional as F
import torchvision
import numpy as np
from torchvision.models.detection.transform  import  GeneralizedRCNNTransform

In [2]:
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


parser = argparse.ArgumentParser(description='PyTorch GAF_2 Training')

args = parser.parse_known_args()[0]

EPOCH = 50   

transform_train = transforms.Compose([
    transforms.RandomRotation(20),
    transforms.RandomResizedCrop((800,800)),
    transforms.RandomHorizontalFlip(),  
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 
])

transform_test = transforms.Compose([
    transforms.Resize((800,800)),    
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 
])
trainset = ImageFolder(root='../GAF_2_Data/Train/', transform=transform_train) 
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=True, num_workers=256)   

testset = ImageFolder(root='../GAF_2_Data/Val/', transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, num_workers=256)

classes = ('Negative', 'Neutral', 'Positive')


In [3]:
class Net(nn.Module):
    def __init__(self, model):
        super(Net, self).__init__()
        self.backbone = model
        self.fc1_layer = nn.Sequential(nn.Linear(in_features=12544, out_features=1024),nn.ReLU())
        self.fc2_layer = nn.Linear(in_features=1024, out_features=3) 
        self.box_roi_pool = MultiScaleRoIAlign(
                featmap_names=['0', '1', '2', '3'],
                output_size=7,
                sampling_ratio=2)
    def forward(self, x):
        w = x.shape[-2]
        h = x.shape[-1]
        x = self.backbone(x)
        boxes = torch.zeros(x['0'].shape[0],4).to(device)
        boxes[:,2] = w-1
        boxes[:,3] = h-1
        x = self.box_roi_pool(x, [boxes], [(w, h)])
        x = x.view(x.size(0), -1)
        x = self.fc1_layer(x)
        x = self.fc2_layer(x)
        return x 

In [4]:
model = torchvision.models.detection.backbone_utils.resnet_fpn_backbone("resnet50", pretrained=True,trainable_layers=5)
net = Net(model)
net = net.to(device)

In [5]:
print(net)

Net(
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): FrozenBatchNorm2d(256)
          )
        )
        (1): Bottleneck(
    

In [7]:
criterion = nn.CrossEntropyLoss()  #损失函数为交叉熵，多用于多分类问题

params_all = [p for p in net.parameters() if p.requires_grad]
base_params = list(map(id, net.backbone.parameters()))
logits_params = filter(lambda p: id(p) not in base_params, params_all)
pre_epoch = 0
decay = 1
# 训练
if __name__ == "__main__":
    best_acc = 75  
    print("Start Training, Resnet50-FPN!") 
    with open("acc_800.txt", "w") as f:
        with open("log_800.txt", "w")as f2:
            for epoch in range(pre_epoch, EPOCH):
                print('\nEpoch: %d' % (epoch + 1))
                net.train()
                sum_loss = 0.0
                correct = 0.0
                total = 0.0
                if epoch < 1:
                    lr_backbone = 0
                    lr_roialign = 1e-3                    
                else:
                    lr_backbone = 1e-3
                    lr_roialign = 1e-3
                    decay = decay * 0.8
                params = [{'params': logits_params, 'lr':decay*lr_roialign},
                            {'params': net.backbone.parameters(), 'lr':decay*lr_backbone}]   
                
                optimizer = optim.SGD(params, momentum=0.2, weight_decay = 1e-4) 
                
                for i, data in enumerate(trainloader, 0):
                    length = len(trainloader)
                    inputs, labels = data
                    inputs, labels = inputs.to(device), labels.to(device)
                    optimizer.zero_grad()

                    # forward + backward
                    outputs = net(inputs)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    optimizer.step()

                    sum_loss += loss.item()
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += predicted.eq(labels.data).cpu().sum()
                    print('[epoch:%d, iter:%d] Loss: %.03f | Acc: %.3f%% '
                          % (epoch + 1, (i + 1 + epoch * length), sum_loss / (i + 1), 100. * correct / total))
                    f2.write('%03d  %05d |Loss: %.03f | Acc: %.3f%% '
                          % (epoch + 1, (i + 1 + epoch * length), sum_loss / (i + 1), 100. * correct / total))
                    f2.write('\n')
                    f2.flush()
                torch.cuda.empty_cache()
                # 每训练完一个epoch测试一下准确率
                print("Waiting Test!")
                with torch.no_grad():
                    correct = 0
                    total = 0
                    class_correct = list(0. for i in range(3)) 
                    class_total = list(0. for i in range(3))   
                    for data in testloader:
                        net.eval()
                        images, labels = data
                        images, labels = images.to(device), labels.to(device)
                        outputs = net(images)
                        _, predicted = torch.max(outputs.data, 1)
                        total += labels.size(0)
                        correct += (predicted == labels).sum()
                        c = (predicted == labels).squeeze()
                        label = labels   
                        class_correct[label] += c
                        class_total[label] += 1
                    for i in range(3):
                        print('Accuracy of %5s : %2d %%' % (
                            classes[i], 100 * class_correct[i] // class_total[i]))
                    print('测试分类准确率为：%.3f%%' % (100 * correct // total))
                    acc = 100. * correct / total
                    # 将每次测试结果实时写入acc.txt文件中
                    f.write("EPOCH=%03d,Accuracy= %.3f%%" % (epoch + 1, acc))
                    f.write('\n')
                    f.flush()
                    # 记录最佳测试分类准确率并写入best_acc.txt文件中
                    if acc > best_acc:
                        f3 = open("best_acc_800.txt", "w")
                        print('Saving model......')
                        torch.save(net.state_dict(), 'GAF_2_best_net_image_800.pth')
                        f3.write("EPOCH=%d,best_acc= %.3f%%" % (epoch + 1, acc))
                        f3.close()
                        best_acc = acc
            print("Training Finished, TotalEPOCH=%d" % EPOCH)

Start Training, Resnet50-FPN!

Epoch: 1
[epoch:1, iter:1] Loss: 1.129 | Acc: 0.000% 
[epoch:1, iter:2] Loss: 1.121 | Acc: 0.000% 
[epoch:1, iter:3] Loss: 1.127 | Acc: 0.000% 
[epoch:1, iter:4] Loss: 1.134 | Acc: 0.000% 
[epoch:1, iter:5] Loss: 1.133 | Acc: 0.000% 
[epoch:1, iter:6] Loss: 1.128 | Acc: 0.000% 
[epoch:1, iter:7] Loss: 1.127 | Acc: 0.000% 
[epoch:1, iter:8] Loss: 1.124 | Acc: 0.000% 
[epoch:1, iter:9] Loss: 1.120 | Acc: 11.111% 
[epoch:1, iter:10] Loss: 1.115 | Acc: 20.000% 
[epoch:1, iter:11] Loss: 1.113 | Acc: 18.182% 
[epoch:1, iter:12] Loss: 1.110 | Acc: 25.000% 
[epoch:1, iter:13] Loss: 1.110 | Acc: 23.077% 
[epoch:1, iter:14] Loss: 1.108 | Acc: 28.571% 
[epoch:1, iter:15] Loss: 1.106 | Acc: 33.333% 
[epoch:1, iter:16] Loss: 1.106 | Acc: 31.250% 
[epoch:1, iter:17] Loss: 1.101 | Acc: 35.294% 
[epoch:1, iter:18] Loss: 1.101 | Acc: 33.333% 
[epoch:1, iter:19] Loss: 1.099 | Acc: 31.579% 
[epoch:1, iter:20] Loss: 1.098 | Acc: 35.000% 
[epoch:1, iter:21] Loss: 1.099 | Acc:

Process Process-236:
Process Process-43:
Process Process-134:
Process Process-219:
Process Process-23:
Process Process-11:
Process Process-56:
Process Process-14:
Process Process-6:
Process Process-64:
Process Process-91:
Process Process-40:
Process Process-21:
Traceback (most recent call last):
Traceback (most recent call last):
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7efd5a8f35f0>
Traceback (most recent call last):
  File "/usr/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/usr/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/usr/miniconda3/envs/pytorch/lib/python3.7/multiprocessing/process.py", line 140, in join
    res = self._popen.wait(timeout)
  File "/usr/miniconda3/envs/pytorch/lib/python3.7/multiprocessing/popen_fork.

KeyboardInterrupt: 

Process Process-37:
Process Process-50:
Process Process-217:
Process Process-39:
Process Process-42:
Process Process-128:
Process Process-198:
Process Process-235:
Process Process-99:
Process Process-19:
Process Process-71:
Process Process-111:
Process Process-83:
Process Process-256:
Process Process-52:
Process Process-74:
Process Process-241:
Process Process-25:
Process Process-28:
Process Process-80:
Process Process-85:
Traceback (most recent call last):
Process Process-242:
Process Process-35:
Process Process-45:
Process Process-73:
Process Process-135:
Process Process-44:
Traceback (most recent call last):
Process Process-46:
Process Process-255:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Process Process-59:
Traceback (most recent call last):
Process Process-34:
Process Process-210:
Process Process-240:
Process Process-103:
Traceback (most recent call last):
Process Process-27:
  File "/usr/miniconda3/envs/pytorch/lib/p