In [57]:
import clip
import scipy
import torch
from torchvision import *
from torch.utils.data import Dataset, DataLoader
import numpy as np

# Caltech101 download, left disabled in this cell; the dataset is loaded in a later cell.
# data = datasets.Caltech101(root = "./data", target_type = "category", download = True)

In [10]:
# List the model names that the installed `clip` package reports as available.
clip.available_models()

['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']

In [11]:
# Pick the compute device: "cuda" when CUDA is available, otherwise "cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cpu'

In [12]:
# Build the Caltech101 dataset rooted at ./data with download enabled, using
# the "category" target type.
data = datasets.Caltech101(
    root="./data",
    target_type="category",
    download=True,
)

In [8]:
# Keep a sorted copy of the dataset's category names so the order is stable.
# NOTE(review): integer targets presumably index into data.categories; confirm
# that sorting does not change that order before using this list for lookups.
categories = list(data.categories)
categories.sort()
categories

['Faces',
 'Faces_easy',
 'Leopards',
 'Motorbikes',
 'accordion',
 'airplanes',
 'anchor',
 'ant',
 'barrel',
 'bass',
 'beaver',
 'binocular',
 'bonsai',
 'brain',
 'brontosaurus',
 'buddha',
 'butterfly',
 'camera',
 'cannon',
 'car_side',
 'ceiling_fan',
 'cellphone',
 'chair',
 'chandelier',
 'cougar_body',
 'cougar_face',
 'crab',
 'crayfish',
 'crocodile',
 'crocodile_head',
 'cup',
 'dalmatian',
 'dollar_bill',
 'dolphin',
 'dragonfly',
 'electric_guitar',
 'elephant',
 'emu',
 'euphonium',
 'ewer',
 'ferry',
 'flamingo',
 'flamingo_head',
 'garfield',
 'gerenuk',
 'gramophone',
 'grand_piano',
 'hawksbill',
 'headphone',
 'hedgehog',
 'helicopter',
 'ibis',
 'inline_skate',
 'joshua_tree',
 'kangaroo',
 'ketch',
 'lamp',
 'laptop',
 'llama',
 'lobster',
 'lotus',
 'mandolin',
 'mayfly',
 'menorah',
 'metronome',
 'minaret',
 'nautilus',
 'octopus',
 'okapi',
 'pagoda',
 'panda',
 'pigeon',
 'pizza',
 'platypus',
 'pyramid',
 'revolver',
 'rhino',
 'rooster',
 'saxophone',
 'sc

In [29]:
# Collect the target (second element) of every dataset sample, in dataset order.
# Indexing `data` yields an (image, target) pair; the image is discarded here.
labels = [target for _image, target in (data[idx] for idx in range(len(data)))]
labels[0]

0

In [14]:
# Load the CLIP "ViT-B/32" model onto `device`, together with the `preprocess`
# callable that is applied to each image before encoding.
# NOTE: `model` is rebound to the classifier in a later cell (`model = NeuralNet()`).
model, preprocess = clip.load("ViT-B/32", device=device)

100%|███████████████████████████████████████| 338M/338M [00:30<00:00, 11.5MiB/s]


In [18]:
# Encode every dataset image with CLIP, one batch at a time.
batch_size = 32
num_images = len(data)
# Report progress once every 10 batches (every 320 images at batch_size = 32).
progress_every = 10 * batch_size
image_embeddings = []
for start in range(0, num_images, batch_size):
    if start % progress_every == 0:
        print(f"{start}/{num_images}")
    # The last batch may be shorter than batch_size.
    stop = min(start + batch_size, num_images)
    # Use a distinct index name so the comprehension does not shadow the batch offset.
    preprocessed_images = [preprocess(data[idx][0]) for idx in range(start, stop)]
    image_batch = torch.stack(preprocessed_images).to(device)
    # Inference only: skip gradient tracking and keep the results on the CPU.
    with torch.no_grad():
        image_emb_batch = model.encode_image(image_batch).cpu()
    image_embeddings.append(image_emb_batch)

0/8677
320/8677
640/8677
960/8677
1280/8677
1600/8677
1920/8677
2240/8677
2560/8677
2880/8677
3200/8677
3520/8677
3840/8677
4160/8677
4480/8677
4800/8677
5120/8677
5440/8677
5760/8677
6080/8677
6400/8677
6720/8677
7040/8677
7360/8677
7680/8677
8000/8677
8320/8677
8640/8677


In [19]:
# Concatenate the per-batch embedding tensors into a single tensor.
# The name is rebound from a list to a tensor, so guard the call: re-running this
# cell would otherwise hand torch.cat a tensor instead of a sequence of tensors.
if not torch.is_tensor(image_embeddings):
    image_embeddings = torch.cat(image_embeddings)
image_embeddings.shape

torch.Size([8677, 512])

In [45]:
class ClipEmbeddingsDataset(Dataset):
    """Dataset that pairs precomputed image embeddings with their labels.

    Args:
        image_embeddings: Indexable, sized collection of embeddings, one per sample.
        labels: Indexable, sized collection of labels aligned with
            ``image_embeddings``.

    Raises:
        ValueError: If the two collections differ in length.
    """

    def __init__(self, image_embeddings, labels):
        # Fail fast on misaligned inputs (e.g. a `labels` name that was rebound
        # elsewhere in the notebook) instead of silently serving wrong pairs.
        if len(image_embeddings) != len(labels):
            raise ValueError(
                f"image_embeddings and labels must have the same length, "
                f"got {len(image_embeddings)} and {len(labels)}"
            )
        self.image_embeddings = image_embeddings
        self.labels = labels

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_embeddings)

    def __getitem__(self, idx):
        """Return the ``(embedding, label)`` pair stored at position ``idx``."""
        return self.image_embeddings[idx], self.labels[idx]

In [52]:
# Wrap the embeddings and labels in a dataset and inspect one sample.
embeddings_dataset = ClipEmbeddingsDataset(image_embeddings, labels)
# Use subscript syntax rather than calling the __getitem__ dunder by hand.
embeddings_dataset[8000]

(tensor([ 3.5840e-02, -1.6151e-01, -3.0853e-01, -7.1587e-01, -5.9231e-02,
          4.3419e-02,  1.0449e-01,  6.7824e-01,  5.1617e-01,  6.1116e-01,
          3.2395e-02, -1.3369e-01,  2.3393e-01, -4.5477e-01, -5.9802e-01,
          7.1524e-02,  9.0372e-01,  4.9348e-01,  5.6496e-02,  5.2156e-01,
         -7.3129e-01,  1.3309e-01, -4.1486e-01, -4.9551e-01,  2.3558e-01,
          3.8464e-01, -1.5420e-01, -4.8332e-01,  1.9169e-01, -2.1391e-01,
          5.5755e-02, -4.0705e-01,  5.8370e-01, -1.6488e-02,  1.1376e-01,
         -3.4988e-01,  4.0409e-02, -8.9154e-02,  2.6772e-01,  3.4514e-01,
         -5.1853e-01,  3.5546e-01,  7.8249e-01,  5.1927e-02,  8.3776e-02,
         -1.7630e+00,  5.9644e-03, -1.8441e-02,  5.7882e-01, -2.2634e-01,
         -2.9172e-01,  2.3713e-01,  6.4021e-02, -3.9163e-01, -1.0326e-01,
         -1.6542e-02, -3.1119e-02, -4.5635e-01, -1.7915e-01, -5.9720e-01,
          6.0462e-01, -2.6313e-01, -5.2777e-01, -4.9555e-01, -3.1327e-01,
          2.8816e-02, -1.1795e-01, -3.

In [64]:
# Split into train and test sets (80% / 20%).
train_size = int(0.8 * len(embeddings_dataset))
test_size = len(embeddings_dataset) - train_size
# Seed the split so train/test membership is identical on every run; an unseeded
# re-run after training would move training samples into the test set.
split_generator = torch.Generator().manual_seed(0)
train_dataset, test_dataset = torch.utils.data.random_split(
    embeddings_dataset, [train_size, test_size], generator=split_generator
)
train_dataloader = DataLoader(train_dataset, batch_size=10, shuffle=True)
# Accuracy does not depend on sample order, so the test loader is not shuffled.
test_dataloader = DataLoader(test_dataset, batch_size=10, shuffle=False)

In [65]:
# Show the batch size the training DataLoader was configured with.
train_dataloader.batch_size

10

In [67]:
import torch.nn as nn
import torch.nn.functional as F

class NeuralNet(nn.Module):
    """Three-layer fully connected classifier with ReLU activations.

    Args:
        in_features: Size of each input vector (default 512).
        hidden_features: Width of the two hidden layers (default 256).
        num_classes: Number of output scores per sample (default 101).
    """

    def __init__(self, in_features=512, hidden_features=256, num_classes=101):
        super().__init__()
        self.layer1 = nn.Linear(in_features, hidden_features)
        self.layer2 = nn.Linear(hidden_features, hidden_features)
        self.layer3 = nn.Linear(hidden_features, num_classes)
        self._init_weights()

    # __call__ is deliberately not overridden: nn.Module.__call__ already
    # dispatches to forward() and also runs registered hooks, which an override
    # that calls forward() directly would skip.

    def forward(self, x):
        """Return the unnormalized class scores (logits) for the batch ``x``."""
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        # No activation on the last layer: it outputs raw logits.
        x = self.layer3(x)
        return x

    def _init_weights(self):
        """Apply Xavier-uniform init to Linear weights and zero the biases."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)


In [68]:
import torch.optim as optim

# Train the classifier on the training embeddings.
# NOTE: this rebinds `model`, which previously held the CLIP model loaded earlier.
model = NeuralNet()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model = model.to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
n_epochs = 100
for epoch in range(n_epochs):
    model.train()
    running_loss = 0.0
    # The batch variables carry a `batch_` prefix so the loop does not overwrite
    # the notebook-level `labels` list that the embeddings dataset was built from.
    for batch_inputs, batch_labels in train_dataloader:
        batch_inputs, batch_labels = batch_inputs.to(device), batch_labels.to(device)
        optimizer.zero_grad()
        output = model(batch_inputs)
        loss = criterion(output, batch_labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()  # Accumulate the batch loss for monitoring
    # Report the mean per-batch loss for this epoch.
    print(f"Epoch [{epoch+1}/{n_epochs}], Loss: {running_loss/len(train_dataloader):.4f}")


cpu
Epoch [1/100], Loss: 2.9809
Epoch [2/100], Loss: 1.3570
Epoch [3/100], Loss: 0.6709
Epoch [4/100], Loss: 0.4299
Epoch [5/100], Loss: 0.3179
Epoch [6/100], Loss: 0.2513
Epoch [7/100], Loss: 0.2126
Epoch [8/100], Loss: 0.1807
Epoch [9/100], Loss: 0.1588
Epoch [10/100], Loss: 0.1418
Epoch [11/100], Loss: 0.1278
Epoch [12/100], Loss: 0.1172
Epoch [13/100], Loss: 0.1053
Epoch [14/100], Loss: 0.0958
Epoch [15/100], Loss: 0.0901
Epoch [16/100], Loss: 0.0804
Epoch [17/100], Loss: 0.0748
Epoch [18/100], Loss: 0.0700
Epoch [19/100], Loss: 0.0663
Epoch [20/100], Loss: 0.0606
Epoch [21/100], Loss: 0.0584
Epoch [22/100], Loss: 0.0535
Epoch [23/100], Loss: 0.0534
Epoch [24/100], Loss: 0.0467
Epoch [25/100], Loss: 0.0459
Epoch [26/100], Loss: 0.0429
Epoch [27/100], Loss: 0.0419
Epoch [28/100], Loss: 0.0386
Epoch [29/100], Loss: 0.0378
Epoch [30/100], Loss: 0.0352
Epoch [31/100], Loss: 0.0336
Epoch [32/100], Loss: 0.0335
Epoch [33/100], Loss: 0.0303
Epoch [34/100], Loss: 0.0302
Epoch [35/100], Los

In [70]:
# Pull the first batch from the training loader for a manual sanity check.
# NOTE: `input` shadows the Python builtin; the name is kept because the next
# cell reads it.
batch_iterator = iter(train_dataloader)
input, label = next(batch_iterator)
input.shape, label

(torch.Size([10, 512]), tensor([94,  0, 45, 50,  1, 92, 71,  5,  3, 53]))

In [76]:
# Run the classifier on the sample batch without tracking gradients.
with torch.no_grad():
    # Move the batch to the model's device, as the training and evaluation loops do.
    output = model(input.to(device))
output
# torch.argmax(output, 1) would give the predicted class index for each row.

tensor([[ -6.3651,  -2.2537,  -3.9694,  ...,  -4.5904,   6.2325,   3.2399],
        [ 19.9324,   8.4823,  -1.3478,  ...,  -2.6515,   4.0993,  -1.7972],
        [ -3.3025,  -3.6270,  -6.5164,  ...,   2.1753,  -3.0322,   1.5346],
        ...,
        [  1.4780,  -6.7455,   1.7283,  ..., -11.5951,   2.0453,  -4.8570],
        [ -2.1178,  -5.4750,   1.2461,  ...,  -1.3860,  -3.4950,  -7.7456],
        [  3.1898,  -4.4681,   2.1757,  ...,  -1.0387,  -1.9317,  -3.8632]])

In [78]:
# Run inference on test_dataloader and count correct top-1 predictions.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    # The batch variables carry a `batch_` prefix so the loop does not overwrite
    # the notebook-level `labels` list that the embeddings dataset was built from.
    for batch_inputs, batch_labels in test_dataloader:
        batch_inputs, batch_labels = batch_inputs.to(device), batch_labels.to(device)
        outputs = model(batch_inputs)
        # The predicted class is the index of the largest logit in each row.
        predicted = torch.argmax(outputs, 1)
        total += batch_labels.size(0)
        correct += (predicted == batch_labels).sum().item()

In [81]:
# Report the fraction of correct test predictions, then the number of test samples.
accuracy = correct / total
print(accuracy)
print(total)

0.9585253456221198
1736
