In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from torchvision import transforms

from editor import editor18
from pretext import GrayScalePL, SuperResolutionPL, RandomPatchPL, RealImagePL
from pretext import RandomPretextConverter

# Qualitative

In [9]:
# Build the editor network and restore the weights used for the qualitative demo.
# NOTE(review): the argument 3 is presumably the number of image channels -- confirm in `editor`.
model = editor18(3)
model_path = './../model/CENC/STL-10/exp4/E_epoch_8'
# map_location='cpu' lets a checkpoint that was saved from a GPU load on a
# CPU-only machine; this notebook never moves the model off the CPU.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(torch.load(model_path, map_location='cpu'))
# Inference mode; as the last expression of the cell this also displays the module.
model.eval()

Editor(
  (encoder): ResNetEncoder(
    (gate): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (blocks): ModuleList(
      (0): ResNetLayer(
        (blocks): Sequential(
          (0): ResNetBasicBlock(
            (blocks): Sequential(
              (0): Sequential(
                (0): Conv2dAuto(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              )
              (1): ReLU(inplace=True)
              (2): Sequential(
                (0): Conv2dAuto(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              )
            )
            (activate): 

In [10]:
# Preprocessing: convert to a tensor, then standardise with fixed
# three-value mean / std statistics.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.447, 0.440, 0.407], std=[0.260, 0.256, 0.271]),
    ]
)

# Test images, served one at a time in shuffled order.
data_dir = './../data/supervised/test'
test_dataset = torchvision.datasets.ImageFolder(root=data_dir, transform=transform)
testloader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=1,
    shuffle=True,
)

In [11]:
class NormalizeInverse(torchvision.transforms.Normalize):
    """Normalize transform configured to undo ``Normalize(mean, std)``.

    The parent is initialised with ``mean' = -mean / std`` and
    ``std' = 1 / std``, so applying it to ``(x - mean) / std`` yields ``x``.
    """

    def __init__(self, mean, std):
        """Derive the inverse statistics from the forward ``mean`` and ``std``."""
        inverse_std = 1 / torch.as_tensor(std)
        inverse_mean = -torch.as_tensor(mean) * inverse_std
        super().__init__(mean=inverse_mean, std=inverse_std)

    def __call__(self, tensor):
        """Apply the inverse transform to a copy, leaving ``tensor`` untouched."""
        working_copy = tensor.clone()
        return super().__call__(working_copy)

# Inverse transform built from the same statistics passed to Normalize when loading images.
unnormalize = NormalizeInverse(
    mean=[0.447, 0.440, 0.407],
    std=[0.260, 0.256, 0.271],
)

In [12]:
# Helper for viewing a sample image
def imshow(img, figsize=None):
    """Un-normalise an image tensor and display it with matplotlib.

    Args:
        img: 3-D tensor whose axes are reordered with (1, 2, 0) before
            plotting -- assumed to be (channels, height, width) and normalised
            with the statistics that ``unnormalize`` inverts.
        figsize: optional ``(width, height)`` in inches for the new figure;
            ``None`` uses matplotlib's default size.

    The figure size is applied to the figure being drawn. The previous
    version assigned ``plt.rcParams["figure.figsize"] = (96, 96)`` after the
    image was already plotted, which left the current figure unchanged and
    made every later figure 96 x 96 inches.
    """
    # detach/cpu so tensors that track gradients or live on a GPU can be converted.
    img = unnormalize(img.detach().cpu())
    # Un-normalising can overshoot [0, 1] slightly; clip so imshow gets valid RGB floats.
    npimg = np.clip(img.numpy(), 0.0, 1.0)
    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(np.transpose(npimg, (1, 2, 0)))
    ax.axis("off")
    plt.show()

In [13]:
# Accumulator for the image grids that are later written to disk.
images_to_be_saved = list()

In [14]:
# Pretext transforms applied to every sampled image, in display order.
tasks = [
    RealImagePL,
    GrayScalePL,
    SuperResolutionPL,
    RandomPatchPL
]
NUM_SAMPLES = 2  # how many test images to visualise

# Start from an empty list so re-running this cell does not append duplicates.
images_to_be_saved = []

for i, (real_image, _) in enumerate(testloader):
    # Take the first image of the batch and reorder its axes with (1, 2, 0)
    # before handing it to the pretext transforms
    # (assumed CHW -> HWC -- TODO confirm against `pretext`).
    np_real_image = np.transpose(real_image[0].numpy(), (1, 2, 0))

    # One transformed copy per task, stacked into a single array.
    inp = np.array([task(np_real_image) for task in tasks])
    inp = np.transpose(inp, (0, 3, 1, 2))
    inp = torch.from_numpy(inp).float()

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        out = model(inp)

    # Row of inputs stacked above the row of model outputs.
    final = [
        torchvision.utils.make_grid(inp, padding=0),
        torchvision.utils.make_grid(out, padding=0),
    ]
    images_to_be_saved.append(torchvision.utils.make_grid(final, nrow=1, padding=0))
    if i + 1 >= NUM_SAMPLES:
        break

In [15]:
# Write every collected grid to <output_dir>/<n>.png, numbered from 1.
output_dir = './../data/output'
# save_image does not create missing directories, so make sure the target exists.
os.makedirs(output_dir, exist_ok=True)
for i, image in enumerate(images_to_be_saved, start=1):
    fn = os.path.join(output_dir, str(i) + '.png')
    # Undo the dataset normalisation before the grid is written out.
    image = unnormalize(image)
    torchvision.utils.save_image(image, fn, padding=0)

# Quantitative

In [None]:
# Preprocessing shared by the train and test splits.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.447, 0.440, 0.407], std=[0.260, 0.256, 0.271]),
    ]
)

# Locations of the supervised train and test splits.
data_train = './../data/supervised/train'
data_test = './../data/supervised/test'

# Training batches are shuffled; test batches keep dataset order.
trainset = torchvision.datasets.ImageFolder(root=data_train, transform=transform)
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=100,
    shuffle=True,
)

testset = torchvision.datasets.ImageFolder(root=data_test, transform=transform)
testloader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=100,
    shuffle=False,
)

In [None]:
# Fresh editor network restored from the checkpoint used for the quantitative comparison.
model = editor18(3)
model_path = './../model/CENC/STL-10/exp2/E_epoch_199'
# map_location='cpu' lets a checkpoint that was saved from a GPU load on a
# CPU-only machine; this notebook never moves the model off the CPU.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(torch.load(model_path, map_location='cpu'))

In [None]:
class Flatten(torch.nn.Module):
    """Collapse every dimension after the first into one.

    An input of shape ``(N, d1, d2, ...)`` becomes ``(N, d1 * d2 * ...)``.
    """

    def forward(self, input):
        """Return ``input`` reshaped to ``(input.size(0), -1)``.

        ``reshape`` is used instead of ``view`` so non-contiguous inputs
        (e.g. transposed or permuted tensors) are flattened rather than
        raising; contiguous inputs still come back as a view without a copy.
        """
        return input.reshape(input.size(0), -1)

In [None]:
# Classifier: the first child module of `model` (only the encoder),
# followed by a flatten step and a 10-output linear layer.
encoder_only = list(model.children())[:1]
net = torch.nn.Sequential(
    *encoder_only,
    Flatten(),
    torch.nn.Linear(6 * 6 * 512, 10),
)

# Freeze the first stage of `net` only; the remaining layers stay trainable.
first_stage = next(iter(net.children()))
for param in first_stage.parameters():
    param.requires_grad = False

# Hand Adam just the parameters that still require gradients.
param_to_update = (p for p in net.parameters() if p.requires_grad)
optimizer = torch.optim.Adam(param_to_update, lr=0.001)

In [None]:
# Baseline: torchvision's pretrained ResNet-18 with every existing weight frozen.
model_conv = torchvision.models.resnet18(pretrained=True)
for param in model_conv.parameters():
    param.requires_grad_(False)

# Replace the final fully connected layer with a new 10-output layer;
# the new layer's parameters are trainable.
num_ftrs = model_conv.fc.in_features
model_conv.fc = torch.nn.Linear(in_features=num_ftrs, out_features=10)

# Only the new head is passed to the optimiser.
optimizer_conv = torch.optim.Adam(params=model_conv.fc.parameters(), lr=0.001)

In [None]:
# Loss shared by both classifiers (batch-mean cross entropy, the library default).
criterion = torch.nn.CrossEntropyLoss(reduction='mean')

In [None]:
def train(model, optimizer, criterion, loader=None):
    """Run one optimisation pass over a data loader.

    Args:
        model: module to optimise; it is switched to training mode.
        optimizer: optimiser whose gradients are zeroed and which is stepped
            once per batch.
        criterion: callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor.
        loader: iterable of ``(inputs, labels)`` batches. Defaults to the
            notebook-level ``trainloader`` so existing calls keep working.

    Returns:
        float: the sum (not the mean) of the per-batch loss values.
    """
    if loader is None:
        loader = trainloader

    # NOTE(review): train() also switches layers inside frozen sub-modules to
    # training mode, so BatchNorm running statistics keep updating even when
    # the weights have requires_grad=False -- confirm this is intended.
    model.train()
    running_loss = 0.0
    for inputs, labels in loader:
        # Clear gradients left over from the previous batch.
        optimizer.zero_grad()

        # Forward, backward, parameter update.
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    return running_loss

In [None]:
def test(model):
    model.eval()
    correct = 0.0
    total = 0.0
    for data in testloader:
        images, labels = data
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    return correct/total

In [None]:
# Train both classifiers side by side and report their loss / accuracy after every epoch.
for epoch in range(100):
    # One training pass each, in a fixed order, followed by evaluation.
    our_loss = train(net, optimizer, criterion)
    res_loss = train(model_conv, optimizer_conv, criterion)
    our_acc, res_acc = test(net), test(model_conv)

    # Accuracies are fractions; scale them to percentages for the report.
    summary = (epoch + 1, our_loss, 100 * our_acc, res_loss, 100 * res_acc)
    print('Epoch %d: Tansen [loss: %.3f, acc: %2d %%]; Resnet18 [loss: %.3f, acc: %2d %%]' % summary)