In [1]:
from pathlib import Path
import cv2
import collections
import numpy as np
import torchvision
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
from torch.nn import functional as f
from PIL import Image
from matplotlib import pyplot as plt
import time
from moco_model import *
import json

In [2]:
# Root folder that is searched recursively for the training JPEGs used below.
# NOTE(review): absolute, machine-specific path; point it at the local copy of the data before running.
DATA_DIR = 'E:/Mine/亲手/Python/Pytorch/动手学深度学习v2/data/cifar-10-batches-py/train'

In [3]:
def image_pair_matching(net, original_image, matching_image, mode='training'):
    """Score every original image against every matching image.

    The originals are encoded with ``net.encoder_q`` and the candidates with
    ``net.encoder_k``; both embeddings are L2-normalised along dim 1 and then
    compared with a dot product.

    Args:
        net: Object exposing ``encoder_q`` and ``encoder_k`` callables.
        original_image: Input handed to ``net.encoder_q``.
        matching_image: Input handed to ``net.encoder_k``.
        mode: Only ``'training'`` is handled; any other value yields ``None``.

    Returns:
        An ``(n, k)`` tensor whose entry ``[i, j]`` is the dot product of
        normalised query ``i`` and normalised key ``j``, or ``None`` when
        ``mode`` is not ``'training'``.
    """
    if mode != 'training':
        # No other mode is implemented; keep the implicit ``None`` result.
        return None
    queries = f.normalize(net.encoder_q(original_image), dim=1)
    keys = f.normalize(net.encoder_k(matching_image), dim=1)
    # Contract over the embedding dimension: (n, c) x (c, k) -> (n, k).
    return torch.einsum('nc,ck->nk', [queries, keys.T])

def cal_accuracy(preds, label):
    """Count how many rows of ``preds`` have their arg-max at the labelled column.

    Args:
        preds: 2-D score tensor; the prediction for a row is the arg-max along dim 1.
        label: Tensor of expected column indices, one per row.

    Returns:
        The number of matching rows as a Python ``float`` (a count, not a ratio).
    """
    # Cast the predicted indices to the label's dtype/device before comparing.
    predicted = preds.argmax(dim=1).type_as(label)
    hits = (predicted == label).sum()
    return float(hits)

class Accumulator:
    """Keeps `n` running totals that are updated together."""

    def __init__(self, n):
        # One float slot per tracked quantity, all starting at zero.
        self.data = [0.0 for _ in range(n)]

    def add(self, *args):
        """Add each positional value (coerced to ``float``) to the total in the same slot.

        Slots are paired with values via ``zip``, so if fewer values than
        slots are supplied the unmatched trailing slots are dropped.
        """
        totals = []
        for running, increment in zip(self.data, args):
            totals.append(running + float(increment))
        self.data = totals

    def reset(self):
        """Set every current slot back to zero."""
        self.data = [0.0 for _ in self.data]

    def __getitem__(self, idx):
        """Return the running total stored at position ``idx``."""
        return self.data[idx]

class HistoryRecorder:
    """Stores one growing list of recorded values per metric name."""

    def __init__(self, names):
        # Keep the caller's sequence as-is; it fixes the positional order used by ``add``.
        self.names = names
        self.data = dict((key, []) for key in names)

    def add(self, *args):
        """Append the i-th positional value to the history of the i-th name.

        Names and values are paired with ``zip``; surplus items on either
        side are ignored.
        """
        for key, entry in zip(self.names, args):
            history = self.data[key]
            history.append(entry)

    def reset(self):
        """Discard everything recorded so far, keeping the same names."""
        self.data = dict((key, []) for key in self.names)

    def __getitem__(self, name):
        """Return the list of values recorded under ``name``."""
        return self.data[name]

def train_moco_return_metrics(net, criterion, optimizer, epochs, device):
    """Train ``net`` for ``epochs`` epochs and record per-epoch metrics.

    Batches are read from the notebook-level ``train_iter`` and ``val_iter``
    loaders, each of which yields ``(origin, target, label)``.

    Args:
        net: Model called as ``net(origin, target)`` (and with
            ``evaluate=True`` during validation); must expose ``K``,
            ``encoder_q`` and ``encoder_k``.
        criterion: Loss applied to the ``(output, labels)`` pair returned by ``net``.
        optimizer: Optimizer stepped once per training batch.
        epochs: Number of passes over ``train_iter``.
        device: CUDA device handed to every ``.cuda(...)`` call.

    Returns:
        A ``HistoryRecorder`` holding one value per epoch under
        'Train Loss', 'Train Acc', 'Val Loss' and 'Val Acc'. 'Val Acc' is
        aggregated over every validation batch.
    """
    train_metrics = HistoryRecorder(['Train Loss', 'Train Acc', 'Val Loss', 'Val Acc'])
    # NOTE(review): the queue is created with shape (1, net.K) but is sliced along dim 0
    # below -- confirm the intended layout (a flat (net.K,) tensor may have been meant).
    idx_to_label_queue = torch.ones((1, net.K), device=device) * -1
    # NOTE(review): the pointer starts at 1 (torch.ones), not 0 -- confirm this is intended.
    queue_pointer = torch.ones(1, dtype=torch.long)
    # Moving the model is loop-invariant, so do it once instead of once per epoch.
    net.cuda(device)
    for epoch in range(epochs):
        total_loss = 0
        training_correct = 0
        training_size = 0
        for origin, target, label in train_iter:
            net.train()
            origin, target, label = origin.cuda(device), target.cuda(device), label.cuda(device)
            pointer = int(queue_pointer)
            # NOTE(review): the batch size is read as origin[0].shape[0] here but as
            # origin.shape[0] further down -- confirm which one matches the loader output.
            idx_to_label_queue[pointer: pointer + origin[0].shape[0]] = label
            queue_pointer[0] = (pointer + origin[0].shape[0]) % net.K
            output, labels = net(origin, target)
            loss = criterion(output, labels)
            total_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Score the batch with the freshly updated weights, without tracking gradients.
            net.eval()
            with torch.no_grad():
                # NOTE(review): the label queue is passed as `matching_image`, i.e. it is fed
                # to net.encoder_k -- confirm this is what was meant.
                training_correct += cal_accuracy(image_pair_matching(net, origin, idx_to_label_queue), label)
                training_size += origin.shape[0]
        net.eval()
        val_loss = 0
        val_correct = 0
        val_size = 0
        # Validation needs no autograd graph; skipping it saves memory and time.
        with torch.no_grad():
            for origin, target, label in val_iter:
                origin, target, label = origin.cuda(device), target.cuda(device), label.cuda(device)
                output, labels = net(origin, target, evaluate=True)
                val_loss += f.cross_entropy(output, labels).item()
                # NOTE(review): `target_tensor` is not defined anywhere in the visible
                # notebook -- confirm where it comes from (perhaps `target` was meant).
                # Accumulate over all batches, mirroring the training accuracy above;
                # previously only the last batch contributed to 'Val Acc'.
                val_correct += cal_accuracy(image_pair_matching(net, origin, target_tensor), label)
                val_size += origin.shape[0]
        val_acc = val_correct / val_size
        train_metrics.add(total_loss / len(train_iter), training_correct / training_size, val_loss / len(val_iter), val_acc)
        print(f'Epoch {epoch + 1}, Loss {total_loss / len(train_iter)}, Train_acc {training_correct / training_size}, Val_loss {val_loss / len(val_iter)}, Val_acc {val_acc}')
    return train_metrics

In [15]:
# Collect every JPEG below DATA_DIR (recursively), then show the first file's
# name up to its first dot as a quick sanity check of the naming scheme.
original_images = [*Path(DATA_DIR).rglob('*.jpg')]
len(original_images)
original_images[0].name.partition('.')[0]

'airplane_10029'

In [14]:
class CIFAR10ContrastiveLearning(Dataset):
    """JPEG images served as augmented views for contrastive learning.

    Every ``*.jpg`` found recursively under the data directory becomes one
    sample whose label is its position in the sorted file list, so each image
    is its own class.
    """

    def __init__(self, data_dir=None):
        """Index the image files and build the augmentation pipeline.

        Args:
            data_dir: Folder searched recursively for ``*.jpg`` files.
                Defaults to the notebook-level ``DATA_DIR``.
        """
        super().__init__()
        root = Path(DATA_DIR if data_dir is None else data_dir)
        # rglob gives no ordering guarantee; sorting keeps the index -> image
        # mapping (and therefore the labels) stable across runs and machines.
        image_paths = sorted(root.rglob('*.jpg'))

        normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        augmentation = [
            torchvision.transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
            torchvision.transforms.RandomApply([torchvision.transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
            torchvision.transforms.RandomGrayscale(p=0.2),
            torchvision.transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.ToTensor(),
            normalize,
        ]
        # NOTE(review): TwoCropsTransform is not defined in this notebook (presumably it
        # comes from moco_model and yields two augmented crops per call) -- confirm there.
        self.transform = TwoCropsTransform(torchvision.transforms.Compose(augmentation))
        # (file path, instance label) pairs; the label is simply the image's index.
        self.origins = [(str(image_path), index) for index, image_path in enumerate(image_paths)]

    def __getitem__(self, idx):
        """Load image ``idx`` and return two independently transformed samples plus its label.

        Returns:
            ``(origin, target, label)`` where ``origin`` and ``target`` each
            come from a separate ``self.transform`` call on the same RGB image.
        """
        path, label = self.origins[idx]
        # Named `image_file` so the handle does not shadow the module-level alias `f`.
        with open(path, 'rb') as image_file:
            # convert() forces the pixel data to be read while the file is still open.
            sample = Image.open(image_file).convert('RGB')
        return self.transform(sample), self.transform(sample), label

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.origins)

In [15]:
# Build the instance-discrimination dataset and hold out 15% of it for validation.
dataset = CIFAR10ContrastiveLearning()
# Derive the split sizes from the dataset length (42500 / 7500 for 50000 images)
# so random_split does not raise when the number of images differs.
val_size = len(dataset) * 15 // 100
train_size = len(dataset) - val_size
# A seeded generator makes the train/validation membership reproducible across runs.
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(0))
train_iter = DataLoader(train_dataset, batch_size=256, shuffle=True, drop_last=True, pin_memory=True)
val_iter = DataLoader(val_dataset, batch_size=256, drop_last=True, pin_memory=True)
# NOTE(review): neither name is used in the visible cells; presumably `metrics` collects
# results per run and `k_values` lists the queue sizes K to try -- confirm downstream.
metrics = {}
k_values = [16384]


In [23]:
# Sanity check: look at the first training batch only and print, one per line,
# the shape of the first element of `origin`, of `target`, and of the labels.
for origin, target, label in train_iter:
    print(origin[0].shape, target[0].shape, label.shape, sep='\n')
    break

torch.Size([256, 3, 224, 224])
torch.Size([256, 3, 224, 224])
torch.Size([256])
