In [1]:
import os
import pandas as pd
from PIL import Image
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
from pyspark.sql import SparkSession


In [2]:
# Connection settings for the local Wildlens MySQL database.
# Values are read from the environment so credentials do not have to live in
# the notebook; the fallbacks reproduce the previous hard-coded local values.
MYSQL_URL = os.environ.get("WILDLENS_DB_URL", "jdbc:mysql://localhost:3306/wildlens")
MYSQL_DRIVER = "com.mysql.cj.jdbc.Driver"
MYSQL_USER = os.environ.get("WILDLENS_DB_USER", "root")
MYSQL_PASSWORD = os.environ.get("WILDLENS_DB_PASSWORD", "root")

spark = (
    SparkSession.builder
    .appName("Wildlens SQL Loader")
    .config("spark.jars", "ETL/installation/mysql-connector-j-9.1.0.jar")
    .getOrCreate()
)


def read_mysql_table(table: str):
    """Load one table of the Wildlens MySQL database as a Spark DataFrame.

    Args:
        table: Name of the table to read (passed as the JDBC ``dbtable`` option).

    Returns:
        A Spark DataFrame backed by a JDBC read of ``table``.
    """
    return (
        spark.read
        .format("jdbc")
        .option("url", MYSQL_URL)
        .option("driver", MYSQL_DRIVER)
        .option("dbtable", table)
        .option("user", MYSQL_USER)
        .option("password", MYSQL_PASSWORD)
        .load()
    )


images_df = read_mysql_table("wildlens_images")
facts_df = read_mysql_table("wildlens_facts")

# Joining on the column name (instead of an equality expression) keeps a single
# `id_espece` column in the result rather than two identically named ones.
joined_df = (
    images_df.join(facts_df, on="id_espece")
    .select("image", "nom_fr", "id_etat")
)

df = joined_df.toPandas()


25/06/14 21:19:01 WARN Utils: Your hostname, DESKTOP-V5T7J8A resolves to a loopback address: 127.0.1.1; using 172.25.127.84 instead (on interface eth0)
25/06/14 21:19:01 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
25/06/14 21:19:02 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
                                                                                

In [3]:
# `id_etat` values that select each split of the joined image table.
TRAIN_ETAT_ID = 1
VAL_ETAT_ID = 2
TEST_ETAT_ID = 3

train_df = df[df['id_etat'] == TRAIN_ETAT_ID]
val_df = df[df['id_etat'] == VAL_ETAT_ID]
test_df = df[df['id_etat'] == TEST_ETAT_ID]

In [4]:
class WildlensDataset(Dataset):
    """Map-style dataset pairing Wildlens image files with integer species labels.

    Each row of ``dataframe`` must provide an ``image`` column (a path relative
    to ``root_dir``) and a ``nom_fr`` column (the species label).

    Args:
        dataframe: pandas DataFrame with ``image`` and ``nom_fr`` columns.
        root_dir: Directory against which the ``image`` paths are resolved.
        transform: Optional callable applied to the loaded RGB image.
        label2idx: Optional mapping from ``nom_fr`` value to class index. Pass a
            shared mapping so several splits use the same indices; when omitted,
            one is built from the sorted unique labels of ``dataframe``.
    """

    def __init__(self, dataframe, root_dir, transform=None, label2idx=None):
        # Positional 0..n-1 index so __getitem__ can address rows by position.
        self.data = dataframe.reset_index(drop=True)
        self.root_dir = root_dir
        self.transform = transform
        if label2idx is None:
            label2idx = {
                label: idx
                for idx, label in enumerate(sorted(dataframe['nom_fr'].unique()))
            }
        self.label2idx = label2idx

    def __len__(self):
        """Return the number of samples (rows) in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label_index)``.

        Raises:
            KeyError: if the row's ``nom_fr`` value is absent from ``label2idx``.
        """
        row = self.data.iloc[idx]
        img_path = os.path.join(self.root_dir, row['image'])
        # Open inside a context manager so the file handle is released at once;
        # convert() returns a new, fully loaded image that outlives the file.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label = self.label2idx[row['nom_fr']]
        if self.transform:
            image = self.transform(image)
        return image, label

In [5]:
# Training-time preprocessing: resize every image to a fixed 224x224, apply a
# random left/right flip as light augmentation, then convert to a tensor.
# NOTE(review): there is no Normalize step; if the model uses pretrained
# weights, consider adding the matching mean/std normalization.
train_transforms = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
    ]
)

In [6]:
# Evaluation preprocessing (used for validation and test): the same fixed
# 224x224 resize as training, with no random augmentation, then a tensor.
val_transforms = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)

In [7]:
# Build one label -> index mapping from the full dataset so every split shares
# the same class indices (a per-split mapping could assign different indices).
label2idx = {label: idx for idx, label in enumerate(sorted(df['nom_fr'].unique()))}

# NOTE(review): validation and test images are also resolved against the
# `augmented_train` directory; confirm that this folder really holds every
# split's files and not only augmented training images.
IMAGE_ROOT = 'ressource/image/augmented_train'
BATCH_SIZE = 32
SEED = 42

train_dataset = WildlensDataset(train_df, IMAGE_ROOT, transform=train_transforms, label2idx=label2idx)
val_dataset = WildlensDataset(val_df, IMAGE_ROOT, transform=val_transforms, label2idx=label2idx)
test_dataset = WildlensDataset(test_df, IMAGE_ROOT, transform=val_transforms, label2idx=label2idx)

# A dedicated, seeded generator makes the training shuffle order reproducible
# across Restart & Run All without touching torch's global RNG.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)
