In [5]:
import os
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset
import glob
import torch
import torchvision
from torchvision.transforms import v2
from PIL import Image
import json
import numpy as np

class ImageNette(Dataset):
    def __init__(self, base_dir=None):
        self.base_dir =  "/home/malapati/Storage/diffusion-detection/AlignProp/data/imagenette2/val"
        self.classes = os.listdir(self.base_dir)
        self.image_paths,self.labels = [], []
        self.process_label_map()

        for i,cls in enumerate(self.classes):
            for path in glob.glob(os.path.join(self.base_dir,cls+"/*")):
                self.image_paths.append(path)
                self.labels.append(self.idx2label[cls])
        self.transforms = torchvision.transforms.Compose([
                v2.ToTensor(),
                v2.ToDtype(torch.float32),
                v2.Resize((224,224)),
                torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ]
        )
    
    def process_label_map(self,path="/home/malapati/Storage/diffusion-detection/AlignProp/data/imagenette2/imagenet_class_index.json"):
        self.class_idx = json.load(open(path,'r'))
        self.idx2label = {}
        for k in range(len(self.class_idx)):
            self.idx2label[self.class_idx[str(k)][0]] = k
        

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image = np.array(Image.open(self.image_paths[idx]).convert('RGB'))
        label = self.labels[idx]
        if self.transforms:
            image = self.transforms(image)
        return image, label

In [6]:
dataset = ImageNette()



In [7]:
from torchvision import models
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import json


classifier = models.resnet18(pretrained=True)
dataloader = DataLoader(dataset,batch_size=8)

loss_fn = torch.nn.CrossEntropyLoss()
acc = 0
final_loss = []
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
classifier.to(device)
classifier.eval()
for image,label in tqdm(dataloader):
    image = image.to(device)
    label = label.to(device)
    pred = classifier(image)
    pred_cls = torch.argmax(pred,axis=1)
    loss = loss_fn(pred,label).float().detach().cpu().numpy()

    final_loss.append(loss)
    acc += (pred_cls==label).sum().float().detach().cpu().numpy()

print(np.mean(final_loss),acc/len(dataset))

#classifier.fc = torch.nn.Linear(classifier.fc.in_features, config.num_classes)

100%|██████████| 491/491 [00:25<00:00, 19.47it/s]

0.85545194 0.7745222929936306



