In [1]:
from matplotlib import pyplot as plt
from huggingface_hub import hf_hub_download
from datasets import load_dataset
from sklearn.metrics import f1_score
import pandas as pd

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models import resnet18, swin_v2_t
from tqdm import tqdm

In [2]:
# Train/validation image splits, cached under the Kaggle working directory.
ds = load_dataset("mikkoim/aquamonitor-jyu", cache_dir="/kaggle/working/")

# Per-image metadata table (taxon labels); the returned local path is this cell's output.
hf_hub_download(
    repo_id="mikkoim/aquamonitor-jyu",
    repo_type="dataset",
    filename="aquamonitor-jyu.parquet.gzip",
    local_dir="/kaggle/working/",
)

README.md:   0%|          | 0.00/5.73k [00:00<?, ?B/s]

train.tar:   0%|          | 0.00/143M [00:00<?, ?B/s]

val.tar:   0%|          | 0.00/22.5M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/40880 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/6394 [00:00<?, ? examples/s]

aquamonitor-jyu.parquet.gzip:   0%|          | 0.00/1.30M [00:00<?, ?B/s]

'/kaggle/working/aquamonitor-jyu.parquet.gzip'

In [3]:
# Per-image metadata. After stripping ".jpg", the "img" column holds the same ids
# as the dataset's "__key__" field, which is what `preprocess` looks labels up by.
metadata = pd.read_parquet("/kaggle/working/aquamonitor-jyu.parquet.gzip")

# Class vocabulary: taxon groups in sorted order, with name -> index and index -> name maps.
classes = sorted(metadata["taxon_group"].unique())
class_map = dict(zip(classes, range(len(classes))))
class_map_inv = dict(enumerate(classes))

# Drop the file extension so names match dataset keys, then map each image id to its class index.
metadata["img"] = metadata["img"].str.removesuffix(".jpg")
label_dict = dict(zip(metadata["img"], metadata["taxon_group"].map(class_map)))

In [4]:
IMAGE_SIZE = 224
BATCH_SIZE = 16

# Evaluation preprocessing:
#  1. square resize to IMAGE_SIZE x IMAGE_SIZE (aspect ratio is not preserved),
#  2. convert to a float tensor,
#  3. normalize each channel as (x - 0.5) / 0.5.
# NOTE(review): must match the normalization used when the checkpoint was trained — confirm.
tf = transforms.Compose(
    [
        transforms.Resize(size=(IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

def preprocess(batch):
    """Turn a raw batch of dataset rows into model inputs.

    Used through ``Dataset.with_transform``, so it receives a dict of column
    lists and returns a dict of column values.

    Args:
        batch: dict with ``"__key__"`` (image ids) and ``"jpg"`` (images;
            presumably PIL images decoded by ``datasets`` — TODO confirm).

    Returns:
        dict with ``"key"`` (the ids, unchanged), ``"img"`` (list of tensors
        produced by ``tf``) and ``"label"`` (int64 tensor of class indices
        looked up in ``label_dict``; a missing id raises ``KeyError``).
    """
    # Force three channels: ``tf`` ends with a 3-channel Normalize, so a grayscale (L)
    # or CMYK JPEG would otherwise crash there. Has no effect on images already in RGB.
    images = [tf(img.convert("RGB")) for img in batch["jpg"]]
    labels = torch.as_tensor([label_dict[key] for key in batch["__key__"]], dtype=torch.long)
    return {"key": batch["__key__"], "img": images, "label": labels}


# Validation split, with `preprocess` applied on the fly whenever rows are accessed.
eval_ds = ds["validation"].with_transform(preprocess)

# Sequential, unshuffled iteration over the validation set.
eval_loader = DataLoader(eval_ds, batch_size=BATCH_SIZE, shuffle=False)

print(f"Development Size: {eval_ds.num_rows}")

Development Size: 6394


In [5]:
class EnsembleModel(nn.Module):
    """Logit-averaging ensemble of a ResNet-18 and a Swin-V2-T classifier.

    Both backbones are constructed without a ``weights`` argument, and each
    classification head (``fc`` / ``head``) is replaced by a fresh
    ``nn.Linear`` with ``class_num`` outputs. Trained weights are expected to
    be loaded afterwards into ``self.resnet`` and ``self.swin``.

    Args:
        class_num: number of output classes.
    """

    def __init__(self, class_num):
        super().__init__()

        resnet = resnet18()
        resnet.fc = nn.Linear(resnet.fc.in_features, class_num)

        swin = swin_v2_t()
        swin.head = nn.Linear(swin.head.in_features, class_num)

        self.resnet = resnet
        self.swin = swin

    def forward(self, x):
        """Return the element-wise mean of both members' logits.

        Args:
            x: image batch fed unchanged to both members
                (presumably ``(batch, 3, H, W)`` — TODO confirm).

        Returns:
            Averaged logits with ``class_num`` scores per sample.
        """
        resnet_out = self.resnet(x)
        swin_out = self.swin(x)

        # Average raw logits (not softmax probabilities).
        return (resnet_out + swin_out) / 2.0

# No instance is created here on purpose. The model is built in the weight-loading
# cell together with its weights, so an untrained ensemble can never be evaluated by accident.

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the checkpoint on the CPU. The tensors are only copied into the freshly built
# model's parameters, and the model is moved to `device` before evaluation.
# Mapping straight to `device` would pin a second, unused copy of every weight in GPU
# memory for as long as the global `weight_dict` lives.
weight_dict = torch.load("/kaggle/input/aquaensemble/model.pt", weights_only=True, map_location="cpu")

model = EnsembleModel(len(classes))

# Strict loading (the default) raises if any key is missing or unexpected.
model.resnet.load_state_dict(weight_dict["resnet"])
model.swin.load_state_dict(weight_dict["swin"])

<All keys matched successfully>

In [7]:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# model.resnet.load_state_dict(torch.load("/kaggle/input/resnet/pytorch/default/1/fine_tuned_resnet18.pth", map_location=device))
# model.swin.load_state_dict(torch.load("/kaggle/input/swine/pytorch/default/1/model.pth", map_location=device))

In [8]:
model.to(device)
model.eval()

eval_labels = []
eval_preds = []

# Inference only: no autograd graph is needed.
with torch.no_grad():
    for batch in tqdm(eval_loader):
        # Only the images need to go to the device; labels stay on the CPU for scoring.
        images = batch["img"].to(device)
        labels = batch["label"]

        logits = model(images)

        # Collect ground-truth labels and predicted class indices (argmax of averaged logits).
        preds = logits.argmax(dim=1)
        eval_labels.extend(labels.cpu().numpy())
        eval_preds.extend(preds.cpu().numpy())

eval_f1 = f1_score(eval_labels, eval_preds, average="weighted")
print(f"F1-Score der kombinierten Vorhersagen auf dem Validierungsset: {eval_f1:.3f}")

100%|██████████| 400/400 [00:37<00:00, 10.58it/s]

F1-Score der kombinierten Vorhersagen auf dem Validierungsset: 0.807





In [9]:
# Persist both ensemble members in one checkpoint, keyed "resnet" / "swin".
# This is the same layout the weight-loading cell reads back.
torch.save(
    obj=dict(resnet=model.resnet.state_dict(), swin=model.swin.state_dict()),
    f="/kaggle/working/model.pt",
)