In [15]:
import torch, torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torchvision.models as tvm
import timm
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

# Run on the GPU when PyTorch can see a CUDA device; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [18]:
# Preprocessing applied to every image:
#   1. resize to a fixed IMG x IMG square (a (h, w) tuple does not keep aspect ratio),
#   2. convert the PIL image to a float tensor,
#   3. normalise each RGB channel with the ImageNet mean/std, the usual
#      normalisation for ImageNet-pretrained backbones.
IMG = 224
tfm = transforms.Compose(
    [
        transforms.Resize(size=(IMG, IMG)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
# Load the full (unsplit) dataset from the raw folder

In [21]:
# Full, unsplit dataset straight from the raw folder (one sub-folder per class).
# NOTE: despite the names, `train_ds` / `train_dl` here cover *all* images; both
# names are rebound to the actual training split when the split folders are loaded.
train_ds = datasets.ImageFolder(
    root="/Users/zhang/Desktop/slides/COMP90051/group work/archive/raw",
    transform=tfm,
)
train_dl = DataLoader(
    dataset=train_ds,
    batch_size=64,
    num_workers=4,
    shuffle=False,
)

In [22]:
# Show the labels ImageFolder inferred from the raw folder's sub-directory names.
print("class names:", train_ds.classes)
print("class to index:", train_ds.class_to_idx)

class name: ['cane', 'cavallo', 'elefante', 'farfalla', 'gallina', 'gatto', 'mucca', 'pecora', 'ragno', 'scoiattolo']
calss to index: {'cane': 0, 'cavallo': 1, 'elefante': 2, 'farfalla': 3, 'gallina': 4, 'gatto': 5, 'mucca': 6, 'pecora': 7, 'ragno': 8, 'scoiattolo': 9}


In [29]:
# Split the whole dataset into train / val / test subsets in the ratio 0.7 : 0.15 : 0.15 (separately for each class)
import os, random, shutil
from pathlib import Path
from sklearn.model_selection import train_test_split


# --- split configuration ------------------------------------------------------
SEED = 42                   # random_state passed to train_test_split below
SPLIT = (0.7, 0.15, 0.15)   # (train, val, test) fractions; must sum to 1
COPY_MODE = "copy"          # one of "copy", "move", "link" (see safe_copy)

# Input tree (one sub-folder per class) and output root for train/ val/ test/.
SRC_DIR = Path("/Users/zhang/Desktop/slides/COMP90051/group work/archive/raw")
OUT_DIR = Path("/Users/zhang/Desktop/slides/COMP90051/group work/archive/split")

# File suffixes (compared lower-cased) that are treated as images.
IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}

# Side effects: seed Python's `random` module (the split itself is driven by
# random_state=SEED) and make sure the output root exists.
random.seed(SEED)
OUT_DIR.mkdir(exist_ok=True, parents=True)

def safe_copy(src: Path, dst: Path, mode: str = "copy") -> None:
    """Place ``src`` at ``dst`` by copying, moving or symlinking it.

    Parent directories of ``dst`` are created as needed.

    Args:
        src: existing source file.
        dst: full destination path, including the file name.
        mode: ``"copy"`` (``shutil.copy2``: copies data and metadata),
            ``"move"`` (``shutil.move``: ``src`` no longer exists afterwards) or
            ``"link"`` (absolute symlink to ``src``; skipped if ``dst`` already
            exists, including as a dangling symlink).

    Raises:
        ValueError: if ``mode`` is not one of the values above. This is checked
            before anything is created on disk.
    """
    # Validate first so a typo in COPY_MODE does not leave empty directories behind.
    if mode not in ("copy", "move", "link"):
        raise ValueError(f"COPY_MODE must be copy/move/link, got {mode!r}")

    dst.parent.mkdir(parents=True, exist_ok=True)
    if mode == "copy":
        shutil.copy2(src, dst)
    elif mode == "move":
        shutil.move(str(src), str(dst))
    else:  # mode == "link"
        # Path.exists() follows symlinks and is False for a dangling link, after
        # which os.symlink would raise FileExistsError; check is_symlink() too.
        if dst.is_symlink() or dst.exists():
            return
        os.symlink(src.resolve(), dst)


# Split every class folder of SRC_DIR independently into train/val/test, so each
# class keeps roughly the SPLIT proportions, and place the files under
# OUT_DIR/<split>/<class>/ via safe_copy.
# NOTE: nothing in OUT_DIR is ever deleted. Re-running with a different SPLIT or
# SEED leaves files from the earlier run behind, which can leak images between
# splits, so clear OUT_DIR first in that case. With COPY_MODE="move" the raw
# folders are emptied, so a re-run finds no images.
train_r, val_r, test_r = SPLIT
assert abs(train_r + val_r + test_r - 1.0) < 1e-6, f"SPLIT must sum to 1, got {SPLIT}"

class_names = sorted(d.name for d in SRC_DIR.iterdir() if d.is_dir())
print("共发现类别：", class_names)  # "classes found:"

for cls in class_names:
    # Regular files with an image suffix only; a directory named like "x.jpg" is ignored.
    files = sorted(
        p for p in (SRC_DIR / cls).iterdir()
        if p.is_file() and p.suffix.lower() in IMG_EXTS
    )
    if not files:
        print(f"[WARN] 类别 {cls} 没有图片，跳过")  # "class has no images, skipping"
        continue

    # Two-stage split: carve off the train part first, then divide the remainder
    # into val/test in the ratio val_r : test_r.
    val_size = val_r / (val_r + test_r)
    try:
        train_files, temp_files = train_test_split(
            files, test_size=(1 - train_r), random_state=SEED, shuffle=True
        )
        val_files, test_files = train_test_split(
            temp_files, test_size=(1 - val_size), random_state=SEED, shuffle=True
        )
    except ValueError as err:
        # sklearn's message does not say which class failed; too few images is the usual cause.
        raise ValueError(
            f"cannot split class {cls!r} ({len(files)} images) into train/val/test"
        ) from err

    for split_name, split_files in (
        ("train", train_files),
        ("val", val_files),
        ("test", test_files),
    ):
        for p in split_files:
            safe_copy(p, OUT_DIR / split_name / cls / p.name, COPY_MODE)

    print(f"{cls:>12s} -> train:{len(train_files)}  val:{len(val_files)}  test:{len(test_files)}")

print("\n完成！输出目录：", OUT_DIR)  # "done! output directory:"


共发现类别： ['cane', 'cavallo', 'elefante', 'farfalla', 'gallina', 'gatto', 'mucca', 'pecora', 'ragno', 'scoiattolo']
        cane -> train:3404  val:729  test:730
     cavallo -> train:1836  val:393  test:394
    elefante -> train:1012  val:217  test:217
    farfalla -> train:1478  val:317  test:317
     gallina -> train:2168  val:465  test:465
       gatto -> train:1167  val:250  test:251
       mucca -> train:1306  val:280  test:280
      pecora -> train:1273  val:273  test:274
       ragno -> train:3374  val:723  test:724
  scoiattolo -> train:1303  val:279  test:280

完成！输出目录： /Users/zhang/Desktop/slides/COMP90051/group work/archive/split


In [31]:
# Load the train / test / val splits written to the split folder. Every split uses
# the same transform `tfm`, and every loader keeps file order (shuffle=False).
# NOTE(review): shuffle=False on train_dl suits fixed-feature extraction; switch it
# to True if this loader is used for gradient-based training.
train_ds, test_ds, val_ds = (
    datasets.ImageFolder(split_root, transform=tfm)
    for split_root in (
        "/Users/zhang/Desktop/slides/COMP90051/group work/archive/split/train",
        "/Users/zhang/Desktop/slides/COMP90051/group work/archive/split/test",
        "/Users/zhang/Desktop/slides/COMP90051/group work/archive/split/val",
    )
)

train_dl, test_dl, val_dl = (
    DataLoader(ds, batch_size=64, shuffle=False, num_workers=4)
    for ds in (train_ds, test_ds, val_ds)
)

In [32]:
# Check and summarise the splits.
# ImageFolder builds class_to_idx separately for each split from that split's own
# sub-directory names. A class missing from one split would silently shift the
# integer labels of every later class, so verify that the mappings agree.
assert train_ds.class_to_idx == val_ds.class_to_idx == test_ds.class_to_idx, (
    "train/val/test class-to-index mappings differ"
)
print("Train classes:", train_ds.classes)
print("Number of classes:", len(train_ds.classes))
print("Train samples:", len(train_ds))
print("Val samples:", len(val_ds))
print("Test samples:", len(test_ds))

Train classes: ['cane', 'cavallo', 'elefante', 'farfalla', 'gallina', 'gatto', 'mucca', 'pecora', 'ragno', 'scoiattolo']
Number of classes: 10
Train samples: 18321
Val samples: 3926
Test samples: 3932
