In [90]:
import torch
from torchvision import models, transforms
from PIL import Image
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import os
import pandas as pd
from tqdm import tqdm



# Rebuild the ResNet-50 architecture and swap its classifier head for a
# `num_classes`-way linear layer so the checkpoint's state dict lines up
# key-for-key. No pretrained weights: everything comes from the checkpoint.
model = models.resnet50(pretrained=False)
num_ftrs = model.fc.in_features
num_classes = 17
model.fc = torch.nn.Linear(num_ftrs, num_classes)
model_path = 'model_2024-08-07.pth'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# map_location remaps any tensors stored on a CUDA device onto `device`;
# without it, a checkpoint saved from GPU tensors cannot be loaded on a
# CPU-only machine even though `device` falls back to CPU above.
checkpoint = torch.load(model_path, map_location=device)

# Left as the last expression so the cell displays the load report
# (e.g. "<All keys matched successfully>").
model.load_state_dict(checkpoint['model_state_dict'])





<All keys matched successfully>

In [91]:
# Place every parameter and buffer on the inference device chosen earlier.
model = model.to(device=device)


In [92]:
# Inference mode: BatchNorm layers use their running statistics instead of
# batch statistics. train(False) is exactly what eval() does and likewise
# returns the module, so the cell still displays the architecture.
model.train(False)


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [93]:
# Empty accumulator for per-image predicted class indices.
preds_list = list()


In [94]:
class ImageDataset(Dataset):
    """Images listed in a CSV, each paired with the value in the CSV's 2nd column.

    Args:
        csv_file: path (or buffer) readable by ``pd.read_csv``. The first column
            is the image file name relative to ``root_dir``; the second column
            is returned as the label. A column named ``ID`` must exist.
        root_dir: directory that the file names are resolved against.
        transform: optional callable applied to the decoded RGB PIL image.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        self.df = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        # Despite the name, this holds the image IDs (the 'ID' column), not
        # labels; kept as-is for backward compatibility.
        self.img_labels = self.df['ID'].tolist()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``; image is transformed if a transform was given."""
        img_name = os.path.join(self.root_dir, self.df.iloc[idx, 0])
        # Force 3-channel RGB: grayscale / RGBA / palette files would otherwise
        # yield tensors whose channel count mismatches a 3-channel Normalize and
        # a 3-input-channel first conv. convert() decodes the pixels, so the
        # file can be closed right away instead of leaking a handle until GC.
        with Image.open(img_name) as img:
            image = img.convert('RGB')
        label = self.df.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        return image, label


In [95]:
# Evaluation preprocessing: fixed 224x224 resize (no crop / augmentation),
# PIL -> float tensor via ToTensor, then per-channel standardisation with the
# standard ImageNet channel means / stds.
# NOTE(review): presumably mirrors the training-time eval transform — confirm.
tst_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [96]:
# Test split: image file names come from the sample submission's first column;
# whatever labels it carries are discarded by the prediction loop.
tst_dataset = ImageDataset(
    csv_file="data/sample_submission.csv",
    root_dir="data/test/",
    transform=tst_transform,
)

In [97]:
# Batched, order-preserving loader: shuffle=False keeps predictions aligned
# with the row order of tst_dataset.df, which the submission relies on.
# NOTE(review): with the 'spawn' multiprocessing start method (Windows/macOS
# default), worker processes may fail to unpickle a Dataset class defined in
# the notebook itself; set num_workers=0 if loading crashes there.
tst_loader = DataLoader(
    tst_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=2,
    # Pinned host memory only speeds up host->GPU copies; on a CPU-only
    # machine PyTorch warns that it has no effect, so enable it only for CUDA.
    pin_memory=device.type == 'cuda',
)

In [98]:
# Fresh accumulator for the predicted class index of each test image.
preds_list = list()


In [99]:
# Re-initialise here so re-running this cell never appends onto predictions
# from an earlier run (which would break the length match with tst_dataset.df).
preds_list = []
model.eval()  # guard against the model having been left in training mode
with torch.no_grad():  # inference only: no autograd bookkeeping for the whole loop
    for images, _ in tqdm(tst_loader):  # dataset labels are ignored
        # non_blocking overlaps the copy with compute when the batch is pinned.
        images = images.to(device, non_blocking=True)
        preds = model(images)  # (batch, num_classes) scores
        # Predicted class index per image, brought back to host memory.
        preds_list.extend(preds.argmax(dim=1).cpu().numpy())

100%|██████████| 99/99 [00:08<00:00, 11.89it/s]


In [100]:
# One prediction per dataset row, in dataset order (the loader does not shuffle).
assert len(preds_list) == len(tst_dataset.df), (
    f"{len(preds_list)} predictions for {len(tst_dataset.df)} test images"
)
# Build a fresh frame rather than wrapping tst_dataset.df and writing into it:
# df[['ID']] and assign() both return new DataFrames, so the dataset's own
# frame is never modified regardless of pandas' copy semantics.
pred_df = tst_dataset.df[['ID']].assign(target=preds_list)

In [101]:
# Independent sanity check against a fresh read of the sample submission:
# the prediction rows must carry the same IDs in the same order.
sample_submission_df = pd.read_csv("data/sample_submission.csv")
assert len(sample_submission_df) == len(pred_df), (
    f"row count {len(pred_df)} differs from sample submission ({len(sample_submission_df)})"
)
# Compare raw values so a mismatched index yields a clear AssertionError
# instead of pandas' "Can only compare identically-labeled Series" ValueError.
assert (sample_submission_df['ID'].to_numpy() == pred_df['ID'].to_numpy()).all(), (
    "ID order differs from sample submission"
)

In [102]:
# Persist the submission; the row index is not written as a column.
pred_df.to_csv(path_or_buf="pred.csv", index=False)

In [103]:
# Preview the first five predicted rows.
pred_df.head(n=5)

Unnamed: 0,ID,target
0,0008fdb22ddce0ce.jpg,2
1,00091bffdffd83de.jpg,6
2,00396fbc1f6cc21d.jpg,8
3,00471f8038d9c4b6.jpg,13
4,00901f504008d884.jpg,2


In [None]:
# Reload the predictions from the CSV written by the earlier to_csv cell and
# set up a single row of axes for displaying test images.
pred_df = pd.read_csv("pred.csv")

image_dir = 'data/test/'
# NOTE(review): no `import matplotlib.pyplot as plt` appears in the imports
# cell above — add it there, or this line raises NameError on a fresh kernel.
# NOTE(review): 100 axes in one row of a 15x5-inch figure leaves each panel
# ~0.15 in wide; consider a multi-row grid if the rest of this cell allows.
fig, axes = plt.subplots(1, 100, figsize=(15, 5))
