In [1]:
from transformers import SwinConfig, SwinModel

In [None]:
# Swin configuration with library defaults, except for a 10-class label space.
configuration = SwinConfig(num_labels=10)

In [None]:
# Build a Swin backbone from `configuration` (no checkpoint is loaded here).
# Note: `model` is rebound later in this notebook by the classification model.
model = SwinModel(config=configuration)

In [None]:
from transformers import AutoImageProcessor, SwinForImageClassification
import torch

In [None]:
# Image processor matching the pretrained Swin-tiny checkpoint used below.
image_processor = AutoImageProcessor.from_pretrained(
    "microsoft/swin-tiny-patch4-window7-224",
)

In [None]:
from datasets import load_dataset

In [None]:
# Load the dataset from a local folder; `data_dir` names the sub-directory to read.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
data = load_dataset(
    path="C:/Users/xdany/Desktop/for project",
    data_dir="Fast Food Classification V2",
)

In [None]:
# Display the first record of the train split to inspect the loaded fields.
data['train'][0]

In [None]:
def transforms(examples):
    """Add RGB, 250x250 versions of a batch's images under ``pixel_values``.

    Args:
        examples: Batch mapping whose ``"image"`` entry is an iterable of
            objects exposing ``convert(mode)`` and ``resize(size)``
            (presumably PIL images — confirm against the dataset features).

    Returns:
        The same ``examples`` mapping, updated in place with a
        ``"pixel_values"`` list holding one converted and resized image per
        input image, in the original order.
    """
    resized = []
    for picture in examples["image"]:
        rgb_picture = picture.convert("RGB")
        resized.append(rgb_picture.resize((250, 250)))
    examples["pixel_values"] = resized
    return examples

In [None]:
# Run `transforms` over the dataset in batches of 500 and drop the original
# "image" column, leaving the resized copies in "pixel_values".
data = data.map(
    transforms,
    batched=True,
    batch_size=500,
    remove_columns=["image"],
)
# Display the first train record after the mapping.
data['train'][0]

In [None]:
import pickle

In [None]:
# Serialize the mapped dataset object to a pickle file in the working directory.
with open('fastfoodData.pkl', mode='wb') as file:
    pickle.dump(data, file=file)

In [None]:
# Restore the dataset object from the pickle file in the working directory.
# NOTE: unpickling can execute arbitrary code — only load files you created yourself.
with open('fastfoodData.pkl', mode='rb') as file:
    data = pickle.load(file=file)

In [None]:
def process_example(example):
    """Run one example through the image processor and attach its label.

    Args:
        example: Mapping with a ``'pixel_values'`` entry (the image handed to
            the module-level ``image_processor``) and a ``'label'`` entry.

    Returns:
        Whatever ``image_processor(..., return_tensors='pt')`` returns, with
        an extra ``'labels'`` key set to ``example['label']``.
    """
    encoded = image_processor(
        example['pixel_values'],
        return_tensors='pt',
    )
    encoded['labels'] = example['label']
    return encoded

In [None]:
# Smoke-test `process_example` on the first train record and display the result.
process_example(data['train'][0])

In [None]:
def transform(example_batch):
    """Turn a batch of images into processor outputs plus labels.

    Args:
        example_batch: Mapping with a ``'pixel_values'`` iterable of images
            and a ``'label'`` entry holding the batch labels.

    Returns:
        The result of calling the module-level ``image_processor`` on the
        list of images with ``return_tensors='pt'``, with the batch labels
        stored under ``'labels'``.
    """
    # The processor is handed a plain list built from the batch's images.
    images = list(example_batch['pixel_values'])
    encoded = image_processor(images, return_tensors='pt')

    # Carry the labels along so downstream code receives them with the pixels.
    encoded['labels'] = example_batch['label']
    return encoded

In [None]:
# Attach `transform` so it is applied when examples are accessed.
data = data.with_transform(transform=transform)

In [None]:
import torch

def collate_fn(batch):
    """Combine per-example dicts into one batch dict of tensors.

    Args:
        batch: Sequence of mappings, each with a ``'pixel_values'`` tensor
            and a ``'labels'`` value.

    Returns:
        Dict with ``'pixel_values'`` (the per-example tensors stacked along a
        new leading dimension) and ``'labels'`` (a tensor built from the
        per-example labels, in batch order).
    """
    pixel_tensors = [item['pixel_values'] for item in batch]
    label_values = [item['labels'] for item in batch]
    collated = {}
    collated['pixel_values'] = torch.stack(pixel_tensors)
    collated['labels'] = torch.tensor(label_values)
    return collated

In [None]:
import numpy as np
from datasets import load_metric

def compute_metrics(p):
    """Compute classification accuracy for an evaluation prediction object.

    Accuracy is computed directly with NumPy, so this function does not call
    ``datasets.load_metric`` (deprecated, and absent from recent ``datasets``
    releases) and needs no metric script to be fetched.

    Args:
        p: Object exposing ``predictions`` (array of per-class scores, one row
            per example) and ``label_ids`` (array of reference class ids).

    Returns:
        ``{"accuracy": float}`` — the fraction of rows whose arg-max class
        equals the reference label.
    """
    predicted_ids = np.argmax(p.predictions, axis=1)
    references = np.asarray(p.label_ids)
    return {"accuracy": float(np.mean(predicted_ids == references))}

In [None]:
# Class names of the 'label' feature of the train split; used below to size
# the classification head and to build the id/label mappings.
labels = data['train'].features['label'].names

In [None]:
# Load the pretrained Swin-tiny checkpoint with a classification head sized to
# this dataset's labels. `ignore_mismatched_sizes=True` lets loading proceed
# when the checkpoint's head shape differs from the requested `num_labels`.
model = SwinForImageClassification.from_pretrained(
    "microsoft/swin-tiny-patch4-window7-224",
    num_labels=len(labels),
    # String ids <-> class names, in the order given by `labels`.
    id2label={str(idx): name for idx, name in enumerate(labels)},
    label2id={name: str(idx) for idx, name in enumerate(labels)},
    ignore_mismatched_sizes=True,
)

In [None]:
from transformers import TrainingArguments

# Hyperparameters and bookkeeping options for the fine-tuning run.
training_args = TrainingArguments(
    # Raw string: keeps the Windows backslashes literal instead of relying on
    # unrecognized escape sequences (\S, \M), which newer Python versions warn
    # about. The resulting path value is unchanged.
    output_dir=r"E:\Swin Model FineTuned\Model",
    per_device_train_batch_size=16,
    num_train_epochs=3,
    learning_rate=2e-4,
    fp16=True,
    # Evaluate and checkpoint on the same 100-step cadence, keeping at most
    # two checkpoints on disk.
    evaluation_strategy="steps",
    eval_steps=100,
    save_steps=100,
    save_total_limit=2,
    load_best_model_at_end=True,
    logging_steps=10,
    report_to='tensorboard',
    # Keep every column: the dataset transform needs the image column that
    # the Trainer would otherwise drop as unused by the model signature.
    remove_unused_columns=False,
    push_to_hub=False,
)

In [None]:
from transformers import Trainer

# Wire the model, training arguments, data splits, collator and metric
# function into a single Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=data["train"],
    eval_dataset=data["validation"],
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    # The image processor is supplied through the `tokenizer` slot
    # (presumably so it is saved together with the model — confirm).
    tokenizer=image_processor,
)

In [None]:
# Fine-tune, then persist the model, the training metrics and the trainer state.
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics(split="train", metrics=train_results.metrics)
trainer.save_metrics(split="train", metrics=train_results.metrics)
trainer.save_state()

In [None]:
# Evaluate on the held-out test split, then log and save the metrics.
# NOTE(review): the test-split metrics are recorded under the "eval" split
# name — confirm this is intended.
metrics = trainer.evaluate(eval_dataset=data['test'])
trainer.log_metrics(split="eval", metrics=metrics)
trainer.save_metrics(split="eval", metrics=metrics)

In [None]:
# Write a copy of the model to an explicit directory, separate from the
# training output directory.
trainer.save_model(output_dir="E:/some model/some")