In [1]:
import torch, torchvision
import wandb
import datasets

from utils.model_utils import load_model


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Fetch every split of the mineral image dataset as a DatasetDict.
ds = datasets.load_dataset(path='Nech-C/mineralimage5K-98')

In [6]:
ds

DatasetDict({
    train: Dataset({
        features: ['image', 'name', 'description', 'mineral_boxes'],
        num_rows: 12828
    })
    validation: Dataset({
        features: ['image', 'name', 'description', 'mineral_boxes'],
        num_rows: 2749
    })
    test: Dataset({
        features: ['image', 'name', 'description', 'mineral_boxes'],
        num_rows: 2749
    })
})

In [3]:
# Bind each split of the DatasetDict to the short name used by the later cells.
train_ds, val_ds, test_ds = (ds[split] for split in ('train', 'validation', 'test'))

In [4]:
from utils.image_preprocess import train_preprocess, preprocess

In [6]:
train_ds[0]

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1024x831>,
 'name': 9,
 'description': 'graphic pegmatite.\n\n12x8.5 cm.',
 'mineral_boxes': [{'box': [0.15918, 0.06859, 0.89551, 0.90734],
   'confidence': 0.234,
   'label': 'a stone'}]}

In [5]:
# Attach the per-access transforms. The training split gets its own transform
# (train_preprocess); both held-out splits share the plain `preprocess` one.
train_ds.set_transform(train_preprocess)

val_ds.set_transform(preprocess)
test_ds.set_transform(preprocess)

In [8]:
train_ds[0]

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1024x831>,
 'name': 9,
 'description': 'graphic pegmatite.\n\n12x8.5 cm.',
 'mineral_boxes': [{'box': [0.15918, 0.06859, 0.89551, 0.90734],
   'confidence': 0.234,
   'label': 'a stone'}],
 'pixel_values': tensor([[[-1.9809, -1.9809, -1.9809,  ..., -1.9809, -1.9809, -1.9809],
          [-1.9809, -1.9809, -1.9809,  ..., -1.9809, -1.9809, -1.9809],
          [-1.9809, -1.9809, -1.9809,  ..., -1.9809, -1.9809, -1.9809],
          ...,
          [ 0.7419,  0.6734,  0.6221,  ..., -0.1143, -0.1314, -0.1314],
          [ 0.8276,  0.6906,  0.7933,  ..., -0.2171, -0.1657, -0.2513],
          [ 0.9474,  0.7762,  0.8618,  ..., -0.3541, -0.1486, -0.2513]],
 
         [[-1.8957, -1.8957, -1.8957,  ..., -1.8957, -1.8957, -1.8957],
          [-1.8957, -1.8957, -1.8957,  ..., -1.8957, -1.8957, -1.8957],
          [-1.8957, -1.8957, -1.8957,  ..., -1.8957, -1.8957, -1.8957],
          ...,
          [ 1.1856,  1.1331,  1.0630,  ...,  0.36

In [6]:
def collate_fn(examples):
    """Merge a list of per-example dicts into a single batch dict.

    Args:
        examples: list of dicts, each providing a ``"pixel_values"`` tensor
            and a ``"name"`` entry holding the class id.

    Returns:
        dict with ``"pixel_values"`` (the example tensors stacked along a new
        leading dimension) and ``"labels"`` (a tensor built from the
        ``"name"`` entries, in the same order).
    """
    images = []
    targets = []
    for item in examples:
        images.append(item["pixel_values"])
        targets.append(item["name"])

    return {
        "pixel_values": torch.stack(images),
        "labels": torch.tensor(targets),
    }

In [7]:
import numpy as np
import evaluate
accuracy = evaluate.load("accuracy")

# compute_metrics takes a mapping with two entries:
# "predictions", the model logits with one row per example,
# and "label_ids", the ground-truth class ids for those rows.
def compute_metrics(eval_pred):
    """Score the arg-max class of each row of logits against the reference labels.

    Args:
        eval_pred: mapping indexed by the string keys ``"predictions"`` and
            ``"label_ids"``.

    Returns:
        Whatever ``accuracy.compute`` returns for the predicted class ids.
    """
    logits = eval_pred["predictions"]
    # Highest-scoring column per row is the predicted class.
    predicted_classes = np.argmax(logits, axis=1)
    references = eval_pred["label_ids"]
    return accuracy.compute(predictions=predicted_classes, references=references)

In [20]:
# Training hyperparameters, read by the data-loader and training cells.
batch_size = 256               # examples per batch in every DataLoader
lr_rate = 8e-4                 # learning rate handed to the optimizer
num_epoch = 25                 # full passes over the training split
weight_decay = 2e-2            # weight decay handed to the optimizer
label_smoothing_factor = 1e-1  # label smoothing applied to the training loss

In [21]:
# Only the training split is reshuffled each epoch. The evaluation splits are
# read in a fixed order so batch composition (and therefore any per-batch
# statistic) is identical every time they are iterated.
train_data_loader = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)
test_data_loader = torch.utils.data.DataLoader(test_ds, batch_size=batch_size, collate_fn=collate_fn, shuffle=False)
val_data_loader = torch.utils.data.DataLoader(val_ds, batch_size=batch_size, collate_fn=collate_fn, shuffle=False)

In [None]:
import torch
import torch.optim.lr_scheduler as lr_scheduler
import wandb

# Load the network from its TOML config via the project helper.
model = load_model('./configs/models/resnext101_32x8d_ver1.0.0.toml')

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Record every hyperparameter that shapes the run (label smoothing included)
# so runs can be compared from the W&B dashboard.
wandb.init(
    project='mineral-net',
    config={
        "learning_rate": lr_rate,
        "num_epoch": num_epoch,
        "batch_size": batch_size,
        "weight_decay": weight_decay,
        "label_smoothing_factor": label_smoothing_factor,
    }
)

# Keep the exact model config alongside the run.
artifact = wandb.Artifact("config", type="config")
artifact.add_file("./configs/models/resnext101_32x8d_ver1.0.0.toml")
wandb.log_artifact(artifact)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr_rate, weight_decay=weight_decay)

# Stepped once per epoch below: first restart after 10 epochs, then the
# period doubles after each restart.
scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

model.to(device)
for epoch in range(num_epoch):
    # ---- training pass ----
    model.train()
    for batch in train_data_loader:
        pixel_values = batch['pixel_values'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(pixel_values)
        loss = torch.nn.functional.cross_entropy(outputs, labels, label_smoothing=label_smoothing_factor)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        wandb.log({"train_loss": loss.item()})
    scheduler.step()

    # ---- validation pass (plain cross-entropy, no label smoothing) ----
    model.eval()
    with torch.no_grad():
        tot_loss = 0.0
        num_samples = 0
        for batch in val_data_loader:
            pixel_values = batch['pixel_values'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(pixel_values)
            # Sum per batch and divide by the sample count afterwards, so a
            # smaller final batch is not over-weighted in the average.
            loss = torch.nn.functional.cross_entropy(outputs, labels, reduction='sum')
            tot_loss += loss.item()
            num_samples += labels.size(0)

    # One log call keeps val_loss and its epoch on the same W&B step.
    wandb.log({"val_loss": tot_loss / num_samples, "epoch": epoch})

# Final testing phase with metrics
model.eval()
with torch.no_grad():
    tot_loss = 0.0
    num_samples = 0
    all_preds = []
    all_labels = []
    for batch in test_data_loader:
        pixel_values = batch['pixel_values'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(pixel_values)
        loss = torch.nn.functional.cross_entropy(outputs, labels, reduction='sum')
        tot_loss += loss.item()
        num_samples += labels.size(0)
        # Move to CPU right away so logits do not pile up on the device.
        all_preds.append(outputs.cpu())
        all_labels.append(labels.cpu())

    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

# Per-sample test loss and the accuracy over the whole test split, logged together.
wandb.log({
    "test_loss": tot_loss / num_samples,
    **compute_metrics({"predictions": all_preds, "label_ids": all_labels}),
})

wandb.finish()


Using cache found in C:\Users\Nech/.cache\torch\hub\pytorch_vision_main


0,1
train_loss,█▆▅▃▃▁

0,1
train_loss,4.25416


In [53]:
# Accuracy over the whole test split: `all_preds` / `all_labels` are the logits
# and labels concatenated by the final testing loop. The loop variables
# `outputs` / `labels` only hold the last batch that was processed.
test_metrics = compute_metrics({"predictions": all_preds, "label_ids": all_labels})
if wandb.run is not None:  # the training cell may already have closed the run
    wandb.log(test_metrics)
test_metrics

In [None]:
# Persist only the model's state dict (not the module object) to a local file.
torch.save(obj=model.state_dict(), f='101_2.pth')

In [21]:
# Close the active W&B run. The training cell already ends with wandb.finish(),
# so this call only matters when that cell did not reach its last line
# (e.g. it was interrupted) -- NOTE(review): confirm whether this cell is still needed.
wandb.finish()

0,1
train_loss,▁

0,1
train_loss,91.79338
