In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.data.dataset import random_split 
from transformers import BeitImageProcessor, BeitForImageClassification, Trainer, AutoFeatureExtractor, TrainingArguments
from torch.utils.data import TensorDataset
from datasets import load_dataset, load_from_disk, Dataset
import torch.optim as optim
import torch.nn as nn
import json
import pandas as pd
import numpy as np
import transformers
import evaluate
import huggingface_hub
from transformers import AutoFeatureExtractor
from evaluate import evaluator

In [3]:
huggingface_hub.notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [4]:
#Initialise Cuda and check that Cuda is available
device = torch.device("cuda")
print(device)
print(torch.cuda.is_available())

cuda
True


In [5]:
processor = BeitImageProcessor.from_pretrained('microsoft/beit-base-patch16-224-pt22k-ft22k')
model = BeitForImageClassification.from_pretrained('microsoft/beit-base-patch16-224-pt22k-ft22k')
feature_extractor = AutoFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224-pt22k-ft22k')



In [6]:
feature_extractor = AutoFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224-pt22k-ft22k')

print(feature_extractor.size)

{'height': 224, 'width': 224}


In [7]:
#Load Dataset
ds = load_dataset("./dataset")


Resolving data files:   0%|          | 0/6862 [00:00<?, ?it/s]

Found cached dataset imagefolder (/home/felixmorgan/.cache/huggingface/datasets/imagefolder/dataset-69da99f399a4f097/0.0.0/37fbb85cc714a338bea574ac6c7d0b5be5aff46c1862c1989b20e0771199e93f)


  0%|          | 0/1 [00:00<?, ?it/s]

In [8]:
shuffled = ds.shuffle(seed=42)

Loading cached shuffled indices for dataset at /home/felixmorgan/.cache/huggingface/datasets/imagefolder/dataset-69da99f399a4f097/0.0.0/37fbb85cc714a338bea574ac6c7d0b5be5aff46c1862c1989b20e0771199e93f/cache-714093f88edeae8e.arrow


In [9]:
labels = pd.Series(shuffled['train']['label'])

print(labels.value_counts())

8     1160
1      851
0      698
9      692
3      639
10     621
4      591
6      526
2      475
5      377
7      232
dtype: int64


In [10]:
x = list(labels.value_counts().loc[lambda x : x>90].keys())
include_index = [i for i, j in enumerate(labels) if j in x]
include = [j for i, j in enumerate(labels) if j in x]

In [11]:
print(len(set(x)))
print(set(x))

11
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}


In [12]:
def select_indexes(lst, lst_index):
    idx_dict = {}
    for i, val in enumerate(lst):
        if val not in idx_dict:
            idx_dict[val] = []
        elif len(idx_dict[val]) < 500:
            idx_dict[val].append(lst_index[i])
        elif all(len(v) == 500 for v in idx_dict.values()):
            break
    return idx_dict

sample = []
for key, value in select_indexes(include, include_index).items():
    print(key, len(value))
    sample += value

print(len(sample))
print(sample[:10])

9 500
10 500
6 500
4 500
3 500
8 500
2 474
0 500
5 376
1 500
7 231
5081
[29, 30, 31, 54, 88, 93, 105, 106, 107, 113]


In [72]:
sampled = shuffled['train'].select(sample)

In [73]:
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
train_transforms = Compose(
        [
            RandomResizedCrop([224, 224]),
            RandomHorizontalFlip(),
            ToTensor(),
            normalize,
        ]
    )

val_transforms = Compose(
        [
            Resize([224, 224]),
            CenterCrop([224, 224]),
            ToTensor(),
            normalize,
        ]
    )

def preprocess_train(example_batch):
    """Apply train_transforms across a batch."""
    example_batch["pixel_values"] = [
        train_transforms(image.convert("RGB")) for image in example_batch["image"]
    ]
    return example_batch

def preprocess_val(example_batch):
    """Apply val_transforms across a batch."""
    example_batch["pixel_values"] = [val_transforms(image.convert("RGB")) for image in example_batch["image"]]
    return example_batch

In [74]:
model_ds = shuffled['train'].train_test_split(test_size=0.2)

In [75]:
labels = model_ds["train"].features["label"].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = i
    id2label[i] = label

id2label[2]

'frost'

In [76]:
print(model_ds)

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 5489
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 1373
    })
})


In [77]:
train_ds = model_ds['train']
val_split = model_ds['test'].train_test_split(0.5)
val_ds = val_split['train']
test_ds = val_split['test']

In [79]:
train_ds.set_transform(preprocess_train)
val_ds.set_transform(preprocess_val)
test_ds.set_transform(preprocess_val)

In [80]:
def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}
    

In [81]:
metric = evaluate.load("accuracy")

def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)


In [86]:
training_args = TrainingArguments(
    output_dir="./weather-base",
    remove_unused_columns=False,
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=8,
    num_train_epochs=6,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=True,
)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [87]:
model = BeitForImageClassification.from_pretrained('microsoft/beit-base-patch16-224-pt22k-ft22k', num_labels=11,ignore_mismatched_sizes=True, label2id=label2id,
    id2label=id2label,)

loading configuration file config.json from cache at /home/felixmorgan/.cache/huggingface/hub/models--microsoft--beit-base-patch16-224-pt22k-ft22k/snapshots/9da301148150e37e533abef672062fa49f6bda4f/config.json
Model config BeitConfig {
  "architectures": [
    "BeitForImageClassification"
  ],
  "attention_probs_dropout_prob": 0.0,
  "auxiliary_channels": 256,
  "auxiliary_concat_input": false,
  "auxiliary_loss_weight": 0.4,
  "auxiliary_num_convs": 1,
  "drop_path_rate": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "id2label": {
    "0": "dew",
    "1": "fogsmog",
    "2": "frost",
    "3": "glaze",
    "4": "hail",
    "5": "lightning",
    "6": "rain",
    "7": "rainbow",
    "8": "rime",
    "9": "sandstorm",
    "10": "snow"
  },
  "image_size": 224,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "dew": 0,
    "fogsmog": 1,
    "frost": 2,
    "glaze": 3,
    "hail": 4,
    "lightning": 5,
    "rain": 6,
    "ra

In [88]:
trainer = Trainer(
    model = model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
    data_collator=collate_fn,

)

/home/felixmorgan/PycharmProjects/pythonProject/./weather-base is already a clone of https://huggingface.co/ChasingMercer/weather-base. Make sure you pull the latest changes with `repo.git_pull()`.


In [89]:
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

***** Running training *****
  Num examples = 5489
  Num Epochs = 6
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient Accumulation steps = 4
  Total optimization steps = 1026
  Number of trainable parameters = 85770443


Epoch,Training Loss,Validation Loss,Accuracy
0,0.3368,0.277989,0.900875
1,0.2129,0.233291,0.930029
2,0.1827,0.244026,0.921283
3,0.1475,0.230571,0.931487
4,0.1284,0.219156,0.93586
5,0.0526,0.218415,0.93586


***** Running Evaluation *****
  Num examples = 686
  Batch size = 8
Saving model checkpoint to ./weather-base/checkpoint-171
Configuration saved in ./weather-base/checkpoint-171/config.json
Model weights saved in ./weather-base/checkpoint-171/pytorch_model.bin
Image processor saved in ./weather-base/checkpoint-171/preprocessor_config.json
Image processor saved in ./weather-base/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 686
  Batch size = 8
Saving model checkpoint to ./weather-base/checkpoint-342
Configuration saved in ./weather-base/checkpoint-342/config.json
Model weights saved in ./weather-base/checkpoint-342/pytorch_model.bin
Image processor saved in ./weather-base/checkpoint-342/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 686
  Batch size = 8
Saving model checkpoint to ./weather-base/checkpoint-513
Configuration saved in ./weather-base/checkpoint-513/config.json
Model weights saved in ./weather-base/checkpoint-513/pytorch

Upload file pytorch_model.bin:   0%|          | 32.0k/331M [00:00<?, ?B/s]

Upload file runs/Mar07_14-11-04_pop-os/events.out.tfevents.1678198277.pop-os.7668.5: 100%|##########| 22.5k/22…

remote: Scanning LFS files of refs/heads/main for validity...        
remote: LFS file scan complete.        
To https://huggingface.co/ChasingMercer/weather-base
   bcbb22e..aae726f  main -> main

To https://huggingface.co/ChasingMercer/weather-base
   aae726f..e9efa32  main -> main



***** train metrics *****
  epoch                    =          6.0
  total_flos               = 2374797442GF
  train_loss               =        0.292
  train_runtime            =   0:40:02.12
  train_samples_per_second =        13.71
  train_steps_per_second   =        0.427


In [92]:
print(train_results.metrics)

{'train_runtime': 2402.1234, 'train_samples_per_second': 13.71, 'train_steps_per_second': 0.427, 'total_flos': 2.549919337377528e+18, 'train_loss': 0.2920494159759834, 'epoch': 6.0}


In [93]:
evluator_model = BeitForImageClassification.from_pretrained('weather-base/checkpoint-855', num_labels=11,ignore_mismatched_sizes=True, label2id=label2id,
    id2label=id2label)

loading configuration file weather-base/checkpoint-855/config.json
Model config BeitConfig {
  "_name_or_path": "microsoft/beit-base-patch16-224-pt22k-ft22k",
  "architectures": [
    "BeitForImageClassification"
  ],
  "attention_probs_dropout_prob": 0.0,
  "auxiliary_channels": 256,
  "auxiliary_concat_input": false,
  "auxiliary_loss_weight": 0.4,
  "auxiliary_num_convs": 1,
  "drop_path_rate": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "id2label": {
    "0": "dew",
    "1": "fogsmog",
    "2": "frost",
    "3": "glaze",
    "4": "hail",
    "5": "lightning",
    "6": "rain",
    "7": "rainbow",
    "8": "rime",
    "9": "sandstorm",
    "10": "snow"
  },
  "image_size": 224,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "dew": 0,
    "fogsmog": 1,
    "frost": 2,
    "glaze": 3,
    "hail": 4,
    "lightning": 5,
    "rain": 6,
    "rainbow": 7,
    "rime": 8,
    "sandstorm": 9,
    "snow": 10
  },
  "layer_no

In [94]:
trainer = Trainer(
    model = evaluator_model,
    args=training_args,
    train_dataset=model_ds["train"],
    eval_dataset=model_ds["test"],
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
    data_collator=collate_fn,

)

/home/felixmorgan/PycharmProjects/pythonProject/./weather-base is already a clone of https://huggingface.co/ChasingMercer/weather-base. Make sure you pull the latest changes with `repo.git_pull()`.


In [95]:
evaluator = trainer.evaluate(test_ds)

print(evaluator)

***** Running Evaluation *****
  Num examples = 687
  Batch size = 8


{'eval_loss': 0.13604497909545898, 'eval_accuracy': 0.9534206695778749, 'eval_runtime': 16.1141, 'eval_samples_per_second': 42.634, 'eval_steps_per_second': 5.337}
