<a href="https://colab.research.google.com/github/EmilisEm/gmm/blob/master/lab1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### Author LSP: 2213748

#### Author: Emilis Kleinas

#### Variant: Use of `ResNet` model with `Printer`, `Torch` and `Cello` classes

#### What does the program do

1. Downloads images for the specified classes from `OpenImages`
2. Processes the images with the `ResNet` model
3. Calculates precision, accuracy, recall and F1 statistics for the downloaded images


## 1. Mount Google Drive to access and store images

In [None]:
from google.colab import drive
# Path at which Google Drive is mounted; `data_dir` is later built on top of it.
drive_base_uri = '/content/drive'
# Mount Drive at that path so images can be read from and stored on it.
drive.mount(drive_base_uri)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights
from pathlib import Path
from tqdm import tqdm
import PIL.Image
import torchvision.transforms as transforms

In [None]:
# Classes to download and evaluate.
classes = ["Torch", "Cello", "Printer"]
# Passed as `limit` when downloading images.
number_of_images = 100
# `data_dir` is the base directory in which the downloaded images are saved.
# Images are laid out as `{data_dir}/{class_name}/images/*.jpg`,
# where `class_name` is the lowercase name of the class being processed (e.g. "cow").
data_dir = f"{drive_base_uri}/MyDrive/openimages"

## Download images for data classes (Optional if images already downloaded)

In [None]:
!pip install openimages



In [None]:
# Imported here rather than at the top because the package is installed by the previous cell.
from openimages.download import download_dataset
# Download images of every class in `classes` into `data_dir`, capped by `number_of_images`.
# NOTE(review): presumably `limit` applies per class (one progress bar is shown per class) — confirm against the openimages docs.
download_dataset(data_dir, classes, limit=number_of_images)

100%|██████████| 10/10 [00:01<00:00,  5.43it/s]
100%|██████████| 10/10 [00:01<00:00,  7.12it/s]
100%|██████████| 10/10 [00:01<00:00,  7.60it/s]


{'torch': {'images_dir': '/content/drive/MyDrive/openimages/torch/images'},
 'cello': {'images_dir': '/content/drive/MyDrive/openimages/cello/images'},
 'printer': {'images_dir': '/content/drive/MyDrive/openimages/printer/images'}}

## Define custom `Dataset` class

In [None]:
class CustomDataset(Dataset):
    """Images stored on disk as ``{base_dir}/{class_name}/images/*.jpg``.

    Every item is a ``(image, 0, class_name)`` triple: the (optionally
    transformed) image, a constant placeholder ``0`` and the name of the class
    whose directory the file was found in.
    """

    def __init__(self, base_dir, transform, class_names=None):
        """Collect the image paths of every requested class.

        Args:
            base_dir: Directory that contains one lowercase sub-directory per class.
            transform: Callable applied to each RGB PIL image; a falsy value
                leaves the PIL image untouched.
            class_names: Optional iterable of class names to look for. When
                omitted, the module-level `classes` list is used.
        """
        self.transform = transform
        self.samples = []

        if class_names is None:
            class_names = classes

        for dataset_class in class_names:
            # Use the `base_dir` argument; the directory name is the lowercase class name.
            class_path = Path(base_dir) / dataset_class.lower() / "images"
            if class_path.exists():
                # `glob` yields files in arbitrary order; sort so that sample
                # indices are reproducible between runs.
                self.samples.extend(
                    (str(p), dataset_class) for p in sorted(class_path.glob('*.jpg'))
                )

        print(f"Found {len(self.samples)} images")

    def __len__(self):
        """Return the number of image files that were found."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load the image at position `idx`.

        Returns:
            ``(image, 0, class_name)``. If the file cannot be loaded or
            transformed, the error is printed and an all-zero ``3x224x224``
            tensor is returned in place of the image so that iteration can continue.
        """
        img_path, dataset_class = self.samples[idx]
        try:
            with PIL.Image.open(img_path) as img:
                img = img.convert('RGB')
                if self.transform:
                    img = self.transform(img)
                return img, 0, dataset_class
        except Exception as e:
            # Deliberate best-effort: one unreadable file must not abort the whole run.
            print(f"Error loading {img_path}: {e}")
            return torch.zeros((3, 224, 224)), 0, dataset_class

## Initialize pretrained model `ResNet50`

In [None]:
# Pretrained ResNet50 together with the preprocessing pipeline that belongs to its weights.
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()
model = resnet50(weights=weights)

# Map every class of interest to its position in the weights' category list.
# The lookup uses the lowercase class name; `.index` raises if a name is not present.
categories = weights.meta["categories"]
class_indices = {}
for class_name in classes:
    class_indices[class_name] = categories.index(class_name.lower())

print("Model initialized with class indices:")
for cls, idx in class_indices.items():
    print(f"{cls}: {idx}")


Model initialized with class indices:
Torch: 862
Cello: 486
Printer: 742


## Initialize data loader

In [None]:
# Dataset over the images under `data_dir`, preprocessed the way the model weights expect.
dataset = CustomDataset(data_dir, transform=preprocess)

# Loader settings gathered in one place so they are easy to read and tune.
# `shuffle` is off: the order of the images does not matter for evaluation.
loader_options = {
    "batch_size": 32,
    "shuffle": False,
    "num_workers": 1,
    "prefetch_factor": 2,
    "persistent_workers": True,
    "multiprocessing_context": 'fork',
}
data_loader = DataLoader(dataset, **loader_options)

print(f"Dataset size: {len(dataset)} images")
print(f"Using {data_loader.num_workers} workers")
print(f"Using GPU: {torch.cuda.is_available()}")


Found 30 images
Dataset size: 30 images
Using 1 workers
Using GPU: False


## Apply image batches to model and extract results

> Each batch is passed through the model and the softmax class probabilities are collected for every image.



In [None]:
# Run every image through the model once and keep the per-image class probabilities.

# Pick the device a single time; the model is moved before the loop instead of
# being re-sent to the GPU on every batch.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

prediction_batches = []   # one probability tensor per batch, kept on the CPU
all_true_classes = []     # class name of every image, in loader order

with torch.no_grad():
    for batch_images, _, batch_classes in tqdm(data_loader):
        # Softmax over dim 1 turns the model outputs into per-class probabilities.
        predictions = model(batch_images.to(device)).softmax(dim=1)
        prediction_batches.append(predictions.cpu())
        all_true_classes.extend(batch_classes)

# `torch.cat` fails with an unhelpful message on an empty list, so report the
# actual problem when the loader produced no batches.
if not prediction_batches:
    raise RuntimeError(
        "No images were processed; check that the dataset directory contains *.jpg files."
    )

all_predictions = torch.cat(prediction_batches, dim=0)

# Everything the metrics cell needs, bundled in one place.
results = {
    'predictions': all_predictions,
    'true_classes': all_true_classes,
    'class_indices': class_indices
}

100%|██████████| 1/1 [00:13<00:00, 13.40s/it]


## Calculate metrics based on results returned by model

In [None]:
def calculate_metrics(threshold, data=None):
    """Compute one-vs-rest accuracy, precision, recall and F1 for every class.

    An image counts as a positive prediction for a class when the predicted
    probability of that class is strictly greater than `threshold`.

    Args:
        threshold: Probability cut-off for a positive prediction.
        data: Optional dict with the keys ``'predictions'`` (2-D tensor indexed
            as ``[image, class_index]``), ``'true_classes'`` (one class name
            per image) and ``'class_indices'`` (class name -> column of
            ``'predictions'``). Defaults to the module-level `results`.

    Returns:
        ``{class_name: {'accuracy', 'precision', 'recall', 'f1'}}`` where every
        value is a Python float. A metric whose denominator is zero (including
        the case of no images at all) is reported as ``0.0``.
    """
    if data is None:
        data = results

    predictions = data['predictions']
    true_classes = data['true_classes']
    class_indices = data['class_indices']

    total = len(true_classes)
    metrics = {}

    for cls, idx in class_indices.items():
        pred_binary = predictions[:, idx] > threshold
        # Explicit bool dtype keeps the `&` / `~` operations valid for an empty list too.
        is_current_class = torch.tensor(
            [label == cls for label in true_classes], dtype=torch.bool
        )

        # Confusion-matrix counts as plain ints so every metric below is a Python float.
        tp = int((pred_binary & is_current_class).sum())
        fp = int((pred_binary & ~is_current_class).sum())
        tn = int((~pred_binary & ~is_current_class).sum())
        fn = int((~pred_binary & is_current_class).sum())

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        accuracy = (tp + tn) / total if total > 0 else 0.0
        if (precision + recall) > 0:
            f1 = 2 * (precision * recall) / (precision + recall)
        else:
            f1 = 0.0

        metrics[cls] = {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1
        }

    return metrics


In [None]:
import ipywidgets as widgets
from IPython.display import display

def update_threshold(threshold):
    """Print every metric of every class for the given probability threshold.

    Args:
        threshold: Cut-off forwarded to `calculate_metrics`.
    """
    per_class = calculate_metrics(threshold)
    print(f"\nResults with threshold {threshold}:")

    for cls in per_class:
        print(f"\n{cls.capitalize()}:")
        # One line per metric, rounded to three decimals.
        for metric, value in per_class[cls].items():
            print(f"{metric}: {value:.3f}")

# Slider over the full probability range [0, 1] in steps of 0.01, starting at 0.5.
threshold_slider = widgets.FloatSlider(
    value=0.5,
    min=0.0,
    max=1.0,
    step=0.01,
    description="Threshold:"
)

# Bind the slider to `update_threshold`; as the last expression of the cell,
# the resulting interactive widget is what gets displayed.
widgets.interactive(update_threshold, threshold=threshold_slider)

interactive(children=(FloatSlider(value=0.5, description='Threshold:', max=1.0, step=0.01), Output()), _dom_cl…