In [7]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision import datasets
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import pickle

In [8]:
class VGG11Conv(nn.Module):
    """Feature extractor wrapping the ``features`` sub-module of torchvision's VGG-11.

    The full pretrained VGG-11 is instantiated once in the constructor and only
    its ``features`` attribute is retained on this module.
    """

    def __init__(self):
        super().__init__()
        # Build the pretrained network, then keep just its ``features`` part.
        backbone = models.vgg11(pretrained=True)
        self.features = backbone.features

    def forward(self, x):
        """Apply ``self.features`` to ``x`` and return the result."""
        return self.features(x)

In [9]:
# Root folder handed to ImageFolder below.
test_dir = './data'

# Preprocessing pipeline: fixed 224x224 resize, conversion to a tensor, then
# per-channel normalization with the given mean/std.
# NOTE(review): these mean/std values look like the usual ImageNet statistics
# used with torchvision's pretrained models -- confirm.
preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(preprocess_steps)

dataset = datasets.ImageFolder(root=test_dir, transform=transform)

# Batches of 64 samples, drawn in shuffled order.
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=64,
    shuffle=True,
)


In [22]:
def evaluate_vgg11(model, data_loader):
    """Run ``model`` over every batch of ``data_loader`` and collect the results.

    The model is switched to evaluation mode (and left in that mode), and all
    forward passes are executed with gradient tracking disabled.

    Args:
        model: An ``nn.Module``-like object; it must provide ``eval()``,
            ``parameters()`` and be callable on a batch of images.
        data_loader: An iterable yielding ``(images, labels)`` pairs of tensors.

    Returns:
        A tuple ``(x, y)`` where ``x`` is the concatenation along dim 0 of the
        model outputs for every batch and ``y`` is the concatenation along
        dim 0 of the corresponding labels. If ``data_loader`` yields no
        batches, two empty tensors are returned (``y`` with dtype ``long``).
    """
    model.eval()

    # Send inputs to the device holding the model's weights so a model that
    # was moved to an accelerator does not fail with a device mismatch.
    # Models without parameters leave the inputs where they are.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    features = []
    targets = []

    with torch.no_grad():
        for images, labels in data_loader:
            if device is not None:
                images = images.to(device)
            # Under no_grad the output carries no autograd history, so the
            # legacy ``.data`` accessor is unnecessary.
            features.append(model(images))
            targets.append(labels)

    # torch.cat rejects an empty list; give callers a well-defined empty
    # result instead of an opaque RuntimeError.
    if not features:
        return torch.empty(0), torch.empty(0, dtype=torch.long)

    return torch.cat(features, dim=0), torch.cat(targets, dim=0)

In [None]:
# Instantiate the extractor and run it over the whole data loader.
model = VGG11Conv()
x, y = evaluate_vgg11(model, data_loader)

# Persist the outputs and labels together as a two-element list.
with open("data_vgg1", 'wb') as out_file:
    pickle.dump([x, y], out_file)