In [None]:
cd ../src

In [None]:
from d07_visualization.viz_training import plot_acc, plot_training_loss, plot_losses
from d04_mixmatch.wideresnet import WideResNet
from d02_data.load_data import get_dataloaders_ssl
from d02_data.load_data_idxs import get_dataloaders_with_index
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Restore the checkpoint onto the CPU, rebuild the WideResNet from its weights
# and pull out the curves that were logged during training.
model_name = 'sgd_250/checkpoint_30000_unlearns.pt'
saved_model = torch.load(f'../models/{model_name}', map_location=torch.device('cpu'))

model = WideResNet(depth=28, k=2, n_out=10, bias=False)
model.load_state_dict(saved_model['model_state_dict'])
# Optimizer state is not restored here.

# Curves stored in the checkpoint, named after their checkpoint keys.
loss_train, loss_val, acc_train, acc_val = (
    saved_model[key] for key in ('loss_train', 'loss_val', 'acc_train', 'acc_val')
)
loss_batch, lx, lu, lu_weighted = (
    saved_model[key] for key in ('loss_batch', 'lx', 'lu', 'lu_weighted')
)

saved_model.keys()

In [None]:
# Train/val loss: first the full history, then with the first two points
# dropped (presumably so the earliest values do not dominate the y-range).
for first_point in (0, 2):
    plot_training_loss(loss_train[first_point:], loss_val[first_point:], step=1000)
    plt.show()


In [None]:
# Report the best validation accuracy, then plot the train/val accuracy curves.
best_val_acc = max(acc_val)
print('Max val acc: ' + str(best_val_acc))

plt.figure(figsize=(8,6))
plot_acc(acc_train, acc_val, step=1000)

In [None]:
idx1, idx2 = 9000, 11000

# Per-batch loss terms: everything after the first 50 points, then a zoom on
# the idx1..idx2 window.
for window, figsize in ((slice(50, None), (8, 6)), (slice(idx1, idx2), (10, 8))):
    plt.figure(figsize=figsize)
    plot_losses(*(curve[window] for curve in (loss_batch, lx, lu, lu_weighted)))
    plt.show()

In [None]:
# Moving-average smoothing of the per-batch loss curves before plotting.
# mode='valid' keeps only windows lying fully inside each series. The default
# mode='full' zero-pads both ends, which pulls the first and last
# kernel_size-1 points towards 0 and adds kernel_size-1 spurious samples.
plt.figure(figsize=(8,6))

kernel_size = 10
kernel = np.ones(kernel_size) / kernel_size
loss_batch_f = np.convolve(loss_batch, kernel, mode='valid')
lx_f = np.convolve(lx, kernel, mode='valid')
lu_f = np.convolve(lu, kernel, mode='valid')
lu_weighted_f = np.convolve(lu_weighted, kernel, mode='valid')

plot_losses(loss_batch_f, lx_f, lu_f, lu_weighted_f)

# Validation accuracy on the secondary y-axis of a separate figure.
# NOTE(review): ax1 is never drawn on. acc_val appears to be logged once per
# `step=1000` batches (cf. the plot_acc call), while the losses above are per
# batch, so they do not share an x-scale yet. TODO: align before overlaying.
fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
ax2.plot(acc_val, 'b-')

plt.show()

In [None]:
# Find the largest lx value within positions [9500, 10000) and show the values
# around it. np.argmax returns an index relative to the slice, so the slice
# start is added back to get a position usable on lx itself. The name avoids
# `id`, which would shadow the Python builtin.
search_start, search_stop = 9500, 10000
half_window = 5
lx_peak_idx = search_start + int(np.argmax(lx[search_start:search_stop]))
print(lx_peak_idx)
print(np.round(lx[max(lx_peak_idx - half_window, 0):lx_peak_idx + half_window], 2))


In [None]:
def evaluate(dataloader, adam=True, net=None):
    """Compute the mean cross-entropy loss and the accuracy (%) of a classifier.

    Args:
        dataloader: iterable of batches whose first two elements are
            (inputs, labels); any further elements are ignored.
        adam: if True (default), evaluate the global ``ema_model``; otherwise
            evaluate the global ``model``. Ignored when ``net`` is given. The name
            is kept for backward compatibility; it selects the EMA weights.
        net: optional module to evaluate instead of the globals above.

    Returns:
        (loss, acc): cross-entropy averaged over all samples, and accuracy in percent.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    if net is None:
        net = ema_model if adam else model
    # eval() switches layers such as BatchNorm/Dropout to inference behaviour;
    # it applies to whichever network is evaluated, including `model`.
    net.eval()
    # reduction='sum' so that batches of different sizes are weighted per sample.
    criterion = nn.CrossEntropyLoss(reduction='sum')
    correct, total, loss_sum = 0, 0, 0.0
    with torch.no_grad():
        for data in dataloader:
            inputs, labels = data[0], data[1]
            outputs = net(inputs)
            loss_sum += criterion(outputs, labels).item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError('evaluate() got a dataloader that yielded no samples')
    return loss_sum / total, correct / total * 100

# Only the validation and test loaders are used below; the other values
# returned by get_dataloaders_ssl are discarded.
_, _, val_loader, test_loader, _, _, _ = get_dataloaders_ssl(
    path='../data',
    num_labeled=250,
    batch_size=64,
)

In [None]:
# Rebuild the same architecture, load the 'ema_state_dict' weights from the
# checkpoint, and report test accuracy.
ema_model = WideResNet(depth=28, k=2, n_out=10, bias=False)
ema_state = saved_model['ema_state_dict']
ema_model.load_state_dict(ema_state)

test_loss, test_acc = evaluate(test_loader, True)
print("Test accuracy: %.2f" % test_acc)

In [None]:
# Same evaluation on the validation split (adam=True selects ema_model).
val_loss, val_acc = evaluate(val_loader, adam=True)
print("Val accuracy: %.2f" % val_acc)

In [None]:
# Sanity check: the shape of an empty 1-D float array.
a = np.empty(0)
print(a.shape)


In [None]:
# Mean of lu over each of the first ten consecutive 1000-entry chunks,
# rounded to 3 decimals.
chunk_len = 1000
loss_list = [
    round(np.mean(lu[start:start + chunk_len]), 3)
    for start in range(0, 10 * chunk_len, chunk_len)
]
print(loss_list)

## Pseudo Labelling

In [None]:
cd ../src

In [None]:
from d02_data.load_data_idxs import get_dataloaders_with_index
# Loaders from the index-aware data module; presumably unlbl_indxs holds the
# dataset positions of the unlabelled subset (TODO confirm).
(lbl_loader, unlbl_loader, _, _, _, unlbl_indxs, _) = get_dataloaders_with_index(
    path='../data', num_labeled=250, batch_size=64
)

In [None]:
# NOTE(review): this cell contained the incomplete expression `id<`, which raises
# SyntaxError and stops Restart & Run All. Its intent is unclear, so it is
# disabled; restore it as a complete expression if it is still needed.
# id<

In [None]:
# Peek at one batch from the unlabelled loader. DataLoader iterators in recent
# PyTorch versions have no `.next()` method, so use the builtin next().
unlbl_batch = next(iter(unlbl_loader))
# Printed explicitly: a bare expression that is not a cell's last line is not displayed.
print(len(unlbl_batch))
# Third element of the batch; presumably sample indices from the index-aware loader.
print(unlbl_batch[2])

In [None]:
# Stored target for one specific unlabelled-dataset position. 36695 presumably
# comes from the index batch printed above (TODO confirm).
unlbl_targets = unlbl_loader.dataset.targets
unlbl_targets[36695]

In [None]:
# NumPy fancy indexing: indexing with an integer array picks those positions.
import numpy as np
a = np.array([-2, 1, 5, 3, 8, 5, 6])
b = np.array([1, 2, 5])
# .tolist() yields plain ints, so the output reads [1, 5, 5] under NumPy 2 as well.
print(a[b].tolist())
# Check the expected result instead of leaving it as a bare literal.
assert a[b].tolist() == [1, 5, 5]

In [None]:
# Fancy-index the labelled-set targets at positions 0 and 3.
# NOTE(review): this only works if `targets` is an array/tensor; a plain Python
# list (as torchvision's CIFAR10 uses) would raise TypeError. Confirm the type.
lbl_targets = lbl_loader.dataset.targets
lbl_targets[[0, 3]]

In [None]:
# Collect the model's max softmax probability (p_out_max) and arg-max class
# (lbls) for NUM_PSEUDO_BATCHES batches of unlabelled inputs.
NUM_PSEUDO_BATCHES = 1

# A single iterator, so successive batches differ; DataLoader iterators have no
# `.next()` method in recent PyTorch, hence the builtin next().
unlbl_iter = iter(unlbl_loader)
# `model` was constructed in training mode and is never switched elsewhere before
# this point; eval() gives inference behaviour for layers like BatchNorm/Dropout.
model.eval()
confidences, predictions = [], []
with torch.no_grad():  # no autograd graph is needed for inference
    for _ in range(NUM_PSEUDO_BATCHES):
        unlbl_batch = next(unlbl_iter)[0]
        p_out = torch.softmax(model(unlbl_batch), dim=1)
        batch_conf, batch_pred = torch.max(p_out, dim=1)
        confidences.append(batch_conf)
        predictions.append(batch_pred)

# Concatenate once rather than growing a tensor inside the loop.
p_out_max = torch.cat(confidences) if confidences else torch.tensor([])
# Labels kept as float, matching the dtype the original float accumulator produced,
# so they stack with p_out_max into one float matrix in the next cells.
lbls = torch.cat(predictions).float() if predictions else torch.tensor([])

In [None]:

# Append this round's rows (confidence, predicted label, predicted label) to the
# running tensor `b`, one row per sample. Re-running the previous cell and this
# one accumulates more rows.
# NOTE(review): lbls is stacked twice; presumably the third column is meant to
# hold the true label. TODO confirm.
bb = torch.vstack((p_out_max, lbls, lbls)).T
# `b` is last bound to a NumPy index array by the fancy-indexing demo cell, and
# torch.cat rejects ndarrays, so start an empty accumulator in that case.
if not isinstance(globals().get('b'), torch.Tensor):
    b = bb.new_zeros((0, bb.shape[1]))
b = torch.cat((b, bb), dim=0)
b.shape

In [None]:
# Append a zero-filled extra column to b (one zero per row). The row count comes
# from b itself; the previous hard-coded 192 matched only one particular number
# of accumulated samples.
b2 = torch.hstack((b, torch.zeros(b.shape[0], 1, dtype=b.dtype)))
b2.shape

In [None]:
FILTER_THRESHOLD = 0.7
# Keep the rows (samples) of b whose confidence (column 0) exceeds the threshold.
# b holds one row per sample (it is built with `.T`), so `b[:, b[0] > 0.7]` would
# compare the first sample's row against the threshold and use it as a column mask.
a = b[b[:, 0] > FILTER_THRESHOLD]
print(a)
# list.extend() returns None, so its result cannot be assigned; concatenate instead.
# NOTE(review): this joins the confidence and predicted-label columns of the kept
# rows; a[0]/a[1] only meant that under a (field, sample) layout. Confirm intent.
l = a[:, 0].tolist() + a[:, 1].tolist()
l

In [None]:
HIGH_CONF_THRESHOLD = 0.95
# Distribution of max softmax probabilities, and the share of samples above the threshold.
plt.hist(p_out_max, bins=20)
plt.xlabel('max softmax probability')
plt.ylabel('count')
plt.show()
# Stay in torch and use .item() so a plain float is printed. The value is scaled
# to a percentage to match the label (it was previously a fraction).
frac_above = (p_out_max > HIGH_CONF_THRESHOLD).float().mean().item()
print(f'Percentage above threshold ({HIGH_CONF_THRESHOLD}): {100 * frac_above:.2f}%')