In [45]:
import os
import pandas as pd
import random
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from albumentations import Compose, Normalize, Resize
from albumentations import RandomResizedCrop, HorizontalFlip, VerticalFlip, RandomBrightnessContrast
from albumentations.pytorch import ToTensorV2

from torchvision import models

from tqdm import tqdm
import numpy as np
from PIL import Image

import ast


In [None]:
class tokenize_attributes():
    """
    Tokenize attributes based on their type.
    """
    def __init__(self, cameraModelSTxt, CameraMakerTxt):
        self.habitat_types = ['Mixed woodland (with coniferous and deciduous trees)', 'Unmanaged deciduous woodland',
                              'Forest bog', 'coniferous woodland/plantation', 'Deciduous woodland', 'natural grassland', 'lawn',
                              'Unmanaged coniferous woodland', 'garden', 'wooded meadow, grazing forest', 'dune', 'Willow scrubland', 'heath',
                              'Acidic oak woodland', 'roadside', 'Thorny scrubland', 'park/churchyard', 'Bog woodland', 'hedgerow', 'gravel or clay pit',
                              'salt meadow', 'bog', 'meadow', 'improved grassland', 'other habitat', 'roof', 'fallow field', 'ditch', 'fertilized field in rotation']
        
        self.substrate_types = ['soil', 'leaf or needle litter', 'wood chips or mulch', 'dead wood (including bark)', 'bark',
                                'wood', 'bark of living trees', 'mosses', 'wood and roots of living trees', 'stems of herbs, grass etc',
                                'peat mosses','dead stems of herbs, grass etc', 'fungi', 'other substrate', 'living stems of herbs, grass etc',
                                'living leaves', 'fire spot', 'faeces', 'cones', 'fruits']
        #load the txt files with camera models and camera makers, the text file is already on the form of a list of strings
        with open(cameraModelSTxt, "r", encoding="utf-8") as f:
            camera_models_types = f.read()
        self.camera_models_types = ast.literal_eval(camera_models_types)
        with open(CameraMakerTxt, "r", encoding="utf-8") as f:
            camera_makers_types = f.read()
        self.camera_makers_types = ast.literal_eval(camera_makers_types)
    
    
        self.habitat_types2idx = {habitat: idx for idx, habitat in enumerate(self.habitat_types)}
        self.substrate_types2idx = {substrate: idx for idx, substrate in enumerate(self.substrate_types)}
        self.camera_models2idx = {model: idx for idx, model in enumerate(self.camera_models_types)}
        self.camera_makers2idx = {maker: idx for idx, maker in enumerate(self.camera_makers_types)}
        
        self.num_habitats = len(self.habitat_types) + 1  # +1 for 'missing habitat'
        self.num_substrates = len(self.substrate_types) + 1  # +1 for 'missing substrate'
        self.num_months = 12+1
        self.num_hours = 24 # 0-23, +1 for 'missing hour'
        self.num_camera_models = len(self.camera_models_types) + 1  # +1 for 'missing camera model'
        self.num_camera_makers = len(self.camera_makers_types) + 1  # +1 for 'missing camera maker'
        
        
    
    def tokenize(self, attribute, attribute_type):
        if attribute_type == 'Habitat':
            if attribute not in self.habitat_types:
                return len(self.habitat_types)  # Return index for 'missing habitat'
            else:
                return self.habitat_types2idx[attribute]
        elif attribute_type == 'Substrate':
            if attribute not in self.substrate_types:
                return len(self.substrate_types)
            else:
                return self.substrate_types2idx[attribute]
        elif attribute_type == 'DateTimeOriginal':
            try:
                # Expecting 'yyyy-mm-dd' format
                yymmdd = attribute.split(' ')[0]
                hhmmss = attribute.split(' ')[1]
                month = int(yymmdd.split(':')[1])  # Extract month (1-12)
                hour = int(hhmmss.split(':')[0])  # Extract hour (0-23)
                
                return month-1, hour  # Return month index (0-11)
            except (IndexError, ValueError, AttributeError):
                return self.num_months-1, self.num_hours  # or a default month token, e.g., 0
        elif attribute_type == 'camera_model':
            if attribute not in self.camera_models_types:
                return len(self.camera_models_types)
            else:
                return self.camera_models2idx[attribute]
        elif attribute_type == 'camera_maker':
            if attribute not in self.camera_makers_types:
                return len(self.camera_makers_types)
            else:
                return self.camera_makers2idx[attribute]

In [47]:
class FungiDataset(Dataset):
    def __init__(self, df, path, CameraMakerTxt, cameraModelSTxt, transform=None, multi_modal=False):
        self.df = df
        self.transform = transform
        self.path = path
        self.multi_modal = multi_modal
        self.tokenizer = tokenize_attributes(CameraMakerTxt, cameraModelSTxt)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        file_path = self.df['filename_index'].values[idx]
        # Get label if it exists; otherwise return None
        label = self.df['taxonID_index'].values[idx]  # Get label
        if pd.isnull(label):
            label = -1  # Handle missing labels for the test dataset
        else:
            label = int(label)
            
        if self.multi_modal:
            habitat = self.df['Habitat'].values[idx]
            if pd.isnull(habitat):
                habitat = -1
            else:
                habitat = str(habitat)
            habitat = self.tokenizer.tokenize(habitat, 'Habitat')
                
            latitude = self.df['Latitude'].values[idx]
            if pd.isnull(latitude):
                latitude = -1
            else:
                latitude = float(latitude)
            longitude = self.df['Longitude'].values[idx]
            if pd.isnull(longitude):
                longitude = -1
            else:
                longitude = float(longitude)
            substrate = self.df['Substrate'].values[idx]
            if pd.isnull(substrate):
                substrate = -1
            else:
                substrate = str(substrate)
            substrate = self.tokenizer.tokenize(substrate, 'Substrate')
            eventDate = self.df['DateTimeOriginal'].values[idx]
            if pd.isnull(eventDate):
                month, hour = -1, -1
            else:
                eventDate = str(eventDate)
            month, hour = self.tokenizer.tokenize(eventDate, 'DateTimeOriginal')
            cameraModel = self.df['camera_model'].values[idx]
            if pd.isnull(cameraModel):
                cameraModel = -1
            else:
                cameraModel = str(cameraModel)
            cameraModel = self.tokenizer.tokenize(cameraModel, 'camera_model')
            cameraMaker = self.df['camera_maker'].values[idx]
            if pd.isnull(cameraMaker):
                cameraMaker = -1
            else:
                cameraMaker = str(cameraMaker)
            cameraMaker = self.tokenizer.tokenize(cameraMaker, 'camera_maker')
            
        with Image.open(os.path.join(self.path, file_path)) as img:
            # Convert to RGB mode (handles grayscale images as well)
            image = img.convert('RGB')
        image = np.array(image)

        # Apply transformations if available
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']

        if self.multi_modal:
            return image, label, file_path, habitat, substrate, month, hour, cameraMaker, cameraModel, latitude, longitude
        else:
            return image, label, file_path

def get_transforms(data):
    """
    Return augmentation transforms for the specified mode ('train' or 'valid').
    """
    width, height = 224, 224
    if data == 'train':
        return Compose([
            RandomResizedCrop(size = (width, height), scale=(0.8, 1.0)),
            HorizontalFlip(p=0.5),
            VerticalFlip(p=0.5),
            RandomBrightnessContrast(p=0.2),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ])
    elif data == 'valid':
        return Compose([
            Resize(width, height),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ])
    else:
        raise ValueError("Unknown data mode requested (only 'train' or 'valid' allowed).")

In [48]:
data_file = str('C:/Users/bmsha/sc2025/metadata_1/metadata_with_camera_info.csv')
df = pd.read_csv(data_file)
train_df = df[df['filename_index'].str.startswith('fungi_train')]
image_path = str("C:/Users/bmsha/sc2025/FungiImages")
cameraModelSTxt = r"C:\Users\bmsha\sc2025\metadata_1\camera_models.txt"
CameraMakerTxt = r"C:\Users\bmsha\sc2025\metadata_1\camera_makers.txt"

train_dataset = FungiDataset(train_df, image_path, cameraModelSTxt, CameraMakerTxt, transform=get_transforms(data='train'), multi_modal=True)

In [49]:
#iterate through the dataset
for i in tqdm(range(len(train_dataset))):
    image, label, _, habitat, substrate, eventDate, latitude, longitude = train_dataset[i]
    print(f"Image shape: {image.shape}, Label: {label}, File path: {file_path}, Habitat: {habitat}, Substrate: {substrate}, Event Date: {eventDate}, Latitude: {latitude}, Longitude: {longitude}")
    # You can add more processing or checks here as needed

  0%|          | 0/25863 [01:49<?, ?it/s]


ValueError: too many values to unpack (expected 8)