In [1]:
from torchvision import transforms, models, datasets
from torchvision.transforms.functional import InterpolationMode
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import copy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model = models.resnet18(pretrained=True)

In [3]:
resize_size = 256
crop_size = 224
mean=(0.485, 0.456, 0.406)
std=(0.229, 0.224, 0.225)

test_transform = transforms.Compose(
        [
            transforms.Resize(resize_size, interpolation=InterpolationMode.BILINEAR),
            transforms.CenterCrop(crop_size),
            transforms.PILToTensor(),
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize(mean=mean, std=std),
        ]
    )
test_set = datasets.ImageFolder(f'/home/ubuntu/data/Imagenet/ILSVRC/Data/CLS-LOC/val/', transform=test_transform)
len(test_set)

FileNotFoundError: [Errno 2] No such file or directory: '/home/ubuntu/data/Imagenet/ILSVRC/Data/CLS-LOC/val/'

In [3]:
mean=(0.485, 0.456, 0.406)
crop_size = 224
std=(0.229, 0.224, 0.225)
test_transform = transforms.Compose([
    transforms.CenterCrop((crop_size, crop_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])
test_set = datasets.CIFAR10(root=f'/home/ubuntu/data/cifar10', train=False, download=True, transform=test_transform)
train_set = datasets.CIFAR10(root=f'/home/ubuntu/data/cifar10', train=True, download=True, transform=test_transform)
len(test_set), len(train_set)

Files already downloaded and verified
Files already downloaded and verified


(10000, 50000)

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [29]:
def append_np(original, to_append):
    if original is None:
        return to_append
    else:
        return np.concatenate((original, to_append))

In [5]:
def get_features(model, dataset, batch_size=32, num_workers=8):
    original_fc = copy.deepcopy(model.fc)
    model.fc = torch.nn.Identity()
    model.to(device)
    model.eval()
    dataloader = DataLoader(dataset,batch_size=batch_size, num_workers=num_workers,shuffle=False)
    all_feats = None
    for (images, label) in tqdm(dataloader, total=len(dataset) // batch_size):
        with torch.no_grad():
            output = model(images.to(device)).cpu().numpy()
            if all_feats is None:
                all_feats = output
            else:
                all_feats = np.concatenate((all_feats, output))
    model.fc = original_fc
    return all_feats

In [6]:
cifar_feats_train = get_features(model, train_set)
cifar_feats_test = get_features(model, test_set)
cifar_feats_train.shape, cifar_feats_test.shape

1563it [01:15, 20.62it/s]                          
313it [00:07, 39.46it/s]                         


((50000, 512), (10000, 512))

In [7]:
def get_logits(model, dataset, batch_size=32, num_workers=8):
    model.to(device)
    model.eval()
    dataloader = DataLoader(dataset,batch_size=batch_size, num_workers=num_workers,shuffle=False)
    all_logits = None
    all_labels = None
    for (images, label) in tqdm(dataloader, total=len(dataset) // batch_size):
        with torch.no_grad():
            output = model(images.to(device)).cpu().numpy()
            if all_logits is None:
                all_logits = output
                all_labels = label.numpy()
            else:
                all_logits = np.concatenate((all_logits, output))
                all_labels = np.concatenate((all_labels, label.numpy()))
    return all_logits, all_labels

In [8]:
cifar_logits_train, cifar_labels_train = get_logits(model, train_set)
cifar_logits_test, cifar_labels_test = get_logits(model, test_set)
cifar_logits_train.shape, cifar_logits_test.shape

1563it [02:00, 13.02it/s]                          
313it [00:08, 35.56it/s]                         


((50000, 1000), (10000, 1000))

In [12]:
def learn_reconstruct(logits, features, test_logits, test_features):
    lin = torch.nn.Linear(logits.shape[1], features.shape[1])
    rec_loss = torch.nn.MSELoss()
    epochs = 100
    batch_size = 16
    optimizer = torch.optim.SGD(lin.parameters(), lr = 0.1, momentum = 0.9)
    
    logits = torch.FloatTensor(logits)
    features = torch.FloatTensor(features)
    test_logits = torch.FloatTensor(test_logits)
    test_features = torch.FloatTensor(test_features)

    for i in range(epochs):
        total_loss = 0
        for b in range(0, len(logits), batch_size):
            optimizer.zero_grad()
            x = logits[b:b+batch_size]
            y = features[b:b+batch_size]
            pred_y = lin(x)
            loss = rec_loss(pred_y, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        if i % 10 == 0:
            lin.eval()
            test_loss = 0
            for b in range(0, len(test_logits), batch_size):
                with torch.no_grad():
                    x = test_logits[b:b+batch_size]
                    y = test_features[b:b+batch_size]
                    pred_y = lin(x)
                    loss = rec_loss(pred_y, y)
                    test_loss += loss.item()
            print(f"Epoch {i}, train loss={total_loss}, test loss = {test_loss}")
            lin.train()
            
    return lin

In [13]:
recon_model = learn_reconstruct(cifar_logits_train, cifar_feats_train, cifar_logits_test, cifar_feats_test)

Epoch 0, train loss=71.39638428250328, test loss = 5.738541886210442
Epoch 10, train loss=6.093450883287005, test loss = 1.5566608069930226
Epoch 20, train loss=3.9516613416490145, test loss = 1.0289810613030568
Epoch 30, train loss=2.9362547773635015, test loss = 0.7705695214681327
Epoch 40, train loss=2.318786834453931, test loss = 0.6127101713209413
Epoch 50, train loss=1.897863550373586, test loss = 0.5052875905530527
Epoch 60, train loss=1.5909859931562096, test loss = 0.42697102518286556
Epoch 70, train loss=1.3570018215395976, test loss = 0.36705848306883126
Epoch 80, train loss=1.1727529991185293, test loss = 0.31959150033071637
Epoch 90, train loss=1.024091051222058, test loss = 0.2809903149609454


In [28]:
def do_feat_recon(logits, recon_model_):
    batch_size = 16
    logits = torch.FloatTensor(logits)

    all_feats = None
    with torch.no_grad():
        for b in range(0, len(logits), batch_size):
            x = logits[b:b+batch_size]
            pred_y = recon_model_(x).numpy()
            all_feats = append_np(all_feats, pred_y)
    return all_feats

In [30]:
recon_cifar_feats_test = do_feat_recon(cifar_logits_test, recon_model)
recon_cifar_feats_test.shape

(10000, 512)

In [18]:
def get_preds(feats, lin):
    lin.eval()
    batch_size=16
    all_preds = None
    feats = torch.FloatTensor(feats)
    for b in range(0, len(feats), batch_size):
        with torch.no_grad():
            x = feats[b:b+batch_size]
            pred = lin(x).numpy()
            if all_preds is None:
                all_preds = pred
            else:
                all_preds = np.concatenate((all_preds, pred))
    return all_preds

In [34]:
preds_orig = get_preds(cifar_feats_test, model.fc.cpu())
preds_recon = get_preds(recon_cifar_feats_test,  model.fc.cpu())
preds_orig.shape, preds_recon.shape

((10000, 1000), (10000, 1000))

In [44]:
top_recon = np.argmax(preds_recon, axis=1)
top_orig = np.argmax(preds_orig, axis=1)
np.mean(top_orig == top_recon)

0.9918

Given a resnet18, pretrained on ImageNet, learns the features from the logits with a single linear layer. 
Training on CIFAR-100 train, and testing on CIFAR-100 test, we get a reconstruction, MSE loss of 0.28 total over the 10k test examples.
If we take these reconstructed features and then use the model's linear head, the top-1 prediction matches the original prediction 99.18% of the time.