In [11]:
import torch
import numpy as np
import imgaug.augmenters as iaa

from torch import nn

In [4]:
def get_augmenter():
    """Build the image-augmentation callable used by the MixMatch pipeline.

    The imgaug pipeline chains a random crop (``px=(0, 16)``), a horizontal
    flip applied with probability 0.5 and a Gaussian blur with
    ``sigma=(0, 3.0)``.

    Returns:
        callable: ``augment(images)``. ``images`` is a 4-D array that is
        permuted from channels-first ``(N, C, H, W)`` to channels-last before
        augmentation and permuted back afterwards, so the result has the same
        axis layout as the input.
    """
    seq = iaa.Sequential([
        iaa.Crop(px=(0, 16)),
        iaa.Fliplr(0.5),
        iaa.GaussianBlur(sigma=(0, 3.0)),
    ])

    def augment(images):
        # imgaug operates on channels-last batches: NCHW -> NHWC.
        channels_last = images.transpose(0, 2, 3, 1)
        # augment_images takes the batch positionally; Augmenter.augment()
        # expects its image inputs as keyword arguments.
        augmented = np.asarray(seq.augment_images(channels_last))
        # Inverse permutation, NHWC -> NCHW. Re-applying (0, 2, 3, 1) here
        # would not restore the original layout.
        return augmented.transpose(0, 3, 1, 2)

    return augment

In [5]:
def sharpen(x, T):
    """Sharpen each row of ``x`` with temperature ``T``.

    Every entry is raised to the power ``1 / T`` and each row (axis 1) is
    then divided by its own sum, so rows of non-negative inputs sum to one.

    Args:
        x: 2-D array-like supporting ``**`` and ``.sum(axis=1, keepdims=True)``.
        T: temperature; values below 1 make each row more peaked.

    Returns:
        Array with the same shape as ``x``.
    """
    powered = x ** (1 / T)
    row_totals = powered.sum(axis=1, keepdims=True)
    return powered / row_totals

In [6]:
def mixup(x1, x2, y1, y2, alpha):
    """Convexly combine two batches of inputs and labels (MixMatch-style MixUp).

    A single mixing weight is drawn from ``Beta(alpha, alpha)`` and replaced
    by ``max(w, 1 - w)`` as in the MixMatch paper, so the result always stays
    at least as close to ``(x1, y1)`` as to ``(x2, y2)``.

    Args:
        x1, x2: input batches of identical shape.
        y1, y2: label batches of identical shape.
        alpha: positive concentration parameter of the Beta distribution.

    Returns:
        tuple: ``(x, y)`` with ``x = w * x1 + (1 - w) * x2`` and
        ``y = w * y1 + (1 - w) * y2``.

    Raises:
        ValueError: from NumPy if ``alpha`` is not positive.
    """
    # Both shape parameters must be positive; a negative second parameter
    # makes np.random.beta raise.
    beta = np.random.beta(alpha, alpha)
    # Keep the first operand dominant so mixed labelled samples stay mostly
    # labelled and mixed unlabelled samples stay mostly unlabelled.
    beta = max(beta, 1 - beta)
    x = beta * x1 + (1 - beta) * x2
    y = beta * y1 + (1 - beta) * y2
    return x, y

In [7]:
def mixmatch(x, y, u, model, augment_fn, T=0.5, K=2, alpha=0.75):
    """Produce one MixMatch batch from a labelled and an unlabelled batch.

    Steps:
      1. augment the labelled batch once and the unlabelled batch ``K`` times;
      2. guess labels for the unlabelled batch by averaging ``model`` outputs
         over the ``K`` augmentations and sharpening with temperature ``T``;
      3. shuffle the pool of all augmented samples and MixUp the labelled and
         unlabelled parts against it.

    Args:
        x: labelled inputs.
        y: labels for ``x`` (same row layout as the model output).
        u: unlabelled inputs.
        model: callable applied to an augmented batch.
            NOTE(review): assumed to accept the arrays returned by
            ``augment_fn`` and to return NumPy-compatible class
            probabilities; confirm against the caller.
        augment_fn: callable returning an augmented copy of a batch.
        T: sharpening temperature.
        K: number of augmentations per unlabelled sample.
        alpha: Beta concentration forwarded to ``mixup``.

    Returns:
        tuple: ``(X, U, p, q)``, i.e. mixed labelled inputs, mixed unlabelled
        inputs, mixed labels and mixed guessed labels. ``U`` and ``q`` have
        ``K * len(u)`` rows.
    """
    xb = augment_fn(x)
    ub = [augment_fn(u) for _ in range(K)]

    # Label guessing: average prediction over the K augmentations, sharpened.
    qb = sharpen(sum(model(batch) for batch in ub) / K, T)

    Ux = np.concatenate(ub, axis=0)
    # Every augmentation of a sample shares that sample's guessed label.
    Uy = np.concatenate([qb] * K, axis=0)

    # Pool inputs and labels in the same order so one permutation keeps
    # them aligned.
    pool_x = np.concatenate([Ux, xb], axis=0)
    pool_y = np.concatenate([Uy, y], axis=0)
    # np.random.shuffle works in place and returns None; permutation returns
    # the shuffled index array.
    indices = np.random.permutation(len(pool_x))
    Wx = pool_x[indices]
    Wy = pool_y[indices]

    n_labeled = len(xb)
    X, p = mixup(xb, Wx[:n_labeled], y, Wy[:n_labeled], alpha)
    U, q = mixup(Ux, Wx[n_labeled:], Uy, Wy[n_labeled:], alpha)
    return X, U, p, q

In [8]:
class MixMatchLoss(torch.nn.Module):
    """MixMatch objective: soft-label cross-entropy plus weighted consistency.

    ``loss = CE(p, softmax(logits_x)) + lambda_u * MSE(q, softmax(logits_u))``

    Args:
        lambda_u: weight of the unlabelled (consistency) term.
    """

    def __init__(self, lambda_u=100):
        # Module.__init__ must run before any sub-module is assigned,
        # otherwise nn.Module refuses the attribute assignment.
        super(MixMatchLoss, self).__init__()
        self.lambda_u = lambda_u
        self.mse = torch.nn.MSELoss()

    def forward(self, X, U, p, q, model):
        """Compute the combined loss for one MixMatch batch.

        Args:
            X: labelled input tensor, batch on dim 0.
            U: unlabelled input tensor, batch on dim 0, same trailing shape as X.
            p: soft targets for ``X``, one probability row per sample.
            q: guessed soft targets for ``U``, one probability row per sample.
            model: callable mapping the concatenated batch to class logits of
                shape ``(len(X) + len(U), num_classes)``.

        Returns:
            Scalar tensor holding the total loss.
        """
        # Single forward pass over both parts, stacked along the batch axis.
        inputs = torch.cat([X, U], dim=0)
        logits = model(inputs)

        n_labeled = len(p)
        logits_x = logits[:n_labeled]
        logits_u = logits[n_labeled:]

        # Targets built in NumPy are often float64; align them with the logits.
        p = p.to(logits.dtype)
        q = q.to(logits.dtype)

        # Cross-entropy against soft (mixed) labels, averaged over the batch.
        supervised = -(p * torch.log_softmax(logits_x, dim=1)).sum(dim=1).mean()
        # Consistency term compares probabilities with probabilities, so the
        # logits are normalised before the squared error.
        unsupervised = self.mse(torch.softmax(logits_u, dim=1), q)
        return supervised + self.lambda_u * unsupervised

In [12]:
class ConvBlock(nn.Module):
    """Two 3x3 conv + BatchNorm + ReLU stages followed by 2x2 average pooling.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # kernel 3, stride 1, padding 1: spatial size is preserved.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal conv weights, zero biases, unit BatchNorm scale."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Apply both conv stages, then halve height and width by pooling."""
        x = self.conv1(x)
        x = self.conv2(x)
        # `F` is never imported in this notebook; reach the functional API
        # through the already-imported `nn` instead.
        x = nn.functional.avg_pool2d(x, 2)
        return x
    
class Classifier(nn.Module):
    """Four-stage convolutional classifier with a small fully connected head.

    Args:
        num_classes: size of the output layer (default 1000, chosen to
            comply with fast.ai).
    """

    def __init__(self, num_classes=1000):  # <======== modification to comply with fast.ai
        super().__init__()

        self.conv = nn.Sequential(
            ConvBlock(in_channels=3, out_channels=64),
            ConvBlock(in_channels=64, out_channels=128),
            ConvBlock(in_channels=128, out_channels=256),
            ConvBlock(in_channels=256, out_channels=512),
        )
        # Global average pooling to 1x1; replaces the earlier mean-over-width /
        # max-over-height reduction.  <======== modification to comply with fast.ai
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(512, 128),
            nn.PReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(0.1),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Map an image batch ``(N, 3, H, W)`` to logits ``(N, num_classes)``."""
        x = self.conv(x)
        x = self.avgpool(x)  # (N, 512, 1, 1)
        # The linear head needs (N, 512); without flattening, Linear(512, ...)
        # would be applied to a trailing dimension of size 1 and fail.
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

In [20]:
def borrowed_model(pretrained=False, **kwargs):
    """Factory that builds a ``Classifier``.

    Args:
        pretrained: accepted but not used by this function; no weights are
            loaded here.
        **kwargs: forwarded unchanged to the ``Classifier`` constructor.

    Returns:
        A newly constructed ``Classifier``.
    """
    classifier = Classifier(**kwargs)
    return classifier

In [13]:
def basic_generator(x, y=None, batch_size=32, shuffle=True):
    """Endlessly yield mini-batches from ``x`` (and ``y`` when given).

    The index order is fixed once, either a random permutation or
    ``0..len(x)-1``, and then reused on every pass. The start offset advances
    by ``batch_size`` and wraps modulo ``len(x)``, so the last batch of a pass
    may be shorter than ``batch_size``.

    Args:
        x: array indexable with an integer index array.
        y: optional array aligned with ``x``.
        batch_size: maximum number of samples per batch.
        shuffle: draw a random order instead of the natural one.

    Yields:
        ``(x_batch, y_batch)`` when ``y`` is given, otherwise ``x_batch``.

    Raises:
        ValueError: if ``x`` is empty (raised on the first ``next()``).
    """
    n_samples = len(x)
    if n_samples == 0:
        raise ValueError('basic_generator needs at least one sample')

    # np.random.shuffle shuffles in place and returns None; permutation
    # returns the shuffled indices.
    if shuffle:
        all_indices = np.random.permutation(n_samples)
    else:
        all_indices = np.arange(n_samples)

    i = 0
    while True:
        indices = all_indices[i:i + batch_size]
        if y is not None:
            yield x[indices], y[indices]
        else:
            # Exactly one item per step: without this `else`, a labelled
            # generator would also emit a bare x batch after every pair.
            yield x[indices]
        i = (i + batch_size) % n_samples

In [14]:
def mixmatch_wrapper(x, y, u, model, batch_size=32):
    """Infinite generator of MixMatch batches.

    Creates one augmenter and two batch generators (labelled and unlabelled).
    On every step it draws one batch from each and passes them to
    ``mixmatch`` with the remaining ``mixmatch`` arguments left at their
    defaults.

    Args:
        x: labelled inputs.
        y: labels for ``x``.
        u: unlabelled inputs.
        model: forwarded to ``mixmatch``.
        batch_size: batch size used for both generators.

    Yields:
        Whatever ``mixmatch`` returns for the current pair of batches.
    """
    augment_fn = get_augmenter()
    labeled_batches = basic_generator(x, y, batch_size)
    unlabeled_batches = basic_generator(u, batch_size=batch_size)
    while True:
        labeled_x, labeled_y = next(labeled_batches)
        unlabeled_x = next(unlabeled_batches)
        yield mixmatch(labeled_x, labeled_y, unlabeled_x, model, augment_fn)

In [15]:
def to_torch(*args, device='cuda'):
    """Convert NumPy arrays to torch tensors placed on ``device``.

    Args:
        *args: NumPy arrays; each goes through ``torch.from_numpy``.
        device: target device passed to ``Tensor.to`` (default ``'cuda'``).

    Returns:
        list: one tensor per positional argument, in the same order.
    """
    return [torch.from_numpy(array).to(device) for array in args]

In [16]:
def test(model, test_gen, test_iters):
    acc = []
    for i, (x, y) in enumerate(test_gen):
        x = to_torch(x)
        pred = model(x).to('cpu').argmax(axis=1)
        acc.append(np.mean(pred == y.argmax(axis=1)))
        if i == test_iters:
            break
    print('Accuracy was : {}'.format(np.mean(acc)))

In [17]:
def report(loss_history):
    """Print the mean of the recorded losses and return a fresh, empty history.

    Args:
        loss_history: sequence of losses, either plain numbers or scalar
            tensors; each entry is converted with ``float()``.

    Returns:
        list: a new empty list, meant to replace the caller's history.
    """
    if len(loss_history) == 0:
        # Nothing recorded yet (e.g. the report issued before the first
        # training step); avoid NumPy's mean-of-empty-slice warning and NaN.
        print('Average loss in last epoch was : n/a (no losses recorded)')
    else:
        # float() also accepts scalar tensors that still require grad, which
        # np.mean cannot convert on its own.
        mean_loss = float(np.mean([float(loss) for loss in loss_history]))
        print('Average loss in last epoch was : {}'.format(mean_loss))
    return []

In [18]:
def save(model, iter, train_iters):
    """Write the model's ``state_dict`` to ``model_<epoch>.pth``.

    The file is created relative to the current working directory.

    Args:
        model: object exposing ``state_dict()``.
        iter: index of the current training iteration (the name shadows the
            builtin but is kept for compatibility with existing callers).
        train_iters: number of iterations per epoch; must be non-zero.
    """
    # Completed epochs = iterations so far // iterations per epoch.
    epoch = iter // train_iters
    torch.save(model.state_dict(), 'model_{}.pth'.format(epoch))

In [19]:
def run(model, train_gen, test_gen, epochs, train_iters, test_iters, device, lr=1e-3):
    """Train ``model`` with the MixMatch loss, evaluating once per epoch.

    Every ``train_iters`` iterations (including iteration 0) the running
    losses are reported and reset, the model is evaluated with ``test`` and a
    checkpoint is written with ``save``. Training stops right after the
    checkpoint for epoch number ``epochs``.

    Args:
        model: ``torch.nn.Module`` to optimise; it is moved to ``device``.
        train_gen: iterable yielding ``(x, u, p, q)`` NumPy batches.
        test_gen: iterable passed to ``test``.
        epochs: number of epochs to train.
        train_iters: iterations per epoch; must be positive.
        test_iters: number of batches used for each evaluation.
        device: device on which batches and the model are placed.
        lr: Adam learning rate (defaults to Adam's own default).

    Raises:
        ValueError: if ``train_iters`` is not positive.
    """
    if train_iters <= 0:
        raise ValueError('train_iters must be a positive integer')

    # Batches are moved to `device` below, so the model must live there too.
    model.to(device)
    optim = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = MixMatchLoss()
    loss_history = []
    for i, (x, u, p, q) in enumerate(train_gen):
        if i % train_iters == 0:
            # Epoch boundary: report, evaluate, checkpoint.
            loss_history = report(loss_history)
            test(model, test_gen, test_iters)
            save(model, i, train_iters)
            if i // train_iters == epochs:
                return

        # Train on every batch, including the one drawn at an epoch boundary,
        # so no batch produced by the generator is thrown away.
        optim.zero_grad()
        x, u, p, q = to_torch(x, u, p, q, device=device)
        loss = loss_fn(x, u, p, q, model)
        loss.backward()
        optim.step()
        # Store a plain float: keeping the tensor would retain its autograd
        # graph for the whole epoch.
        loss_history.append(loss.item())

In [None]:
# NOTE(review): this call is incomplete and fails as written. `run` also
# requires train_gen, test_gen, epochs, train_iters, test_iters and device.
# `borrowed_model` is the factory function itself, while `run` calls
# `model.parameters()`, so an instance (e.g. `borrowed_model()`) must be passed.
run(borrowed_model, )