In [1]:
from transformers import ViTFeatureExtractor, ViTForImageClassification
from transformers import TrainingArguments, Trainer
from datasets import load_dataset
from torchvision.datasets import ImageFolder
from torchvision import transforms
from datasets import Dataset
import torch
from transformers import TrainerCallback
import gc

# Set up the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define image transformations (resize, normalize, etc.)
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resizing images to match ViT input size
    transforms.ToTensor(),  # Convert images to tensors
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize based on ImageNet
])

# Load dataset from folders using torchvision's ImageFolder
train_dataset = ImageFolder(root='C:\\Users\\akshg\\Downloads\\traning', transform=transform)
test_dataset = ImageFolder(root='C:\\Users\\akshg\\Downloads\\test', transform=transform)

# Convert to Hugging Face Dataset format
def convert_to_hf_dataset(img_dataset):
    # Get images and labels from ImageFolder dataset
    images, labels = zip(*[(image, label) for image, label in img_dataset])
    return Dataset.from_dict({"image": list(images), "label": list(labels)})

hf_train_dataset = convert_to_hf_dataset(train_dataset)
hf_test_dataset = convert_to_hf_dataset(test_dataset)

# Initialize the feature extractor and model
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224', num_labels=len(train_dataset.classes))

# Move the model to the appropriate device
model.to(device)

# Preprocess the images (extract features)
def preprocess_images(examples):
    images = [image.numpy() for image in examples['image']]
    # Extract features using the feature extractor
    features = feature_extractor(images, return_tensors='pt')
    # Move tensors to the same device as the model
    return {k: v.to(device) for k, v in features.items()}


# Apply preprocessing to datasets
hf_train_dataset = hf_train_dataset.map(preprocess_images, batched=True)
hf_test_dataset = hf_test_dataset.map(preprocess_images, batched=True)



# Set training arguments
training_args = TrainingArguments(
    output_dir='./results',
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=3,
    evaluation_strategy="epoch",
    # Enable mixed precision training if desired
    fp16=True if torch.cuda.is_available() else False
)


class MemoryManagementCallback(TrainerCallback):
    def on_epoch_end(self, args, state, control, **kwargs):
        # Clear memory after every epoch
        print(f"Epoch {state.epoch} finished. Clearing memory...")
        torch.cuda.empty_cache()  # Clear GPU cache
        gc.collect()  # Clear unused objects from memory


# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=hf_train_dataset,
    eval_dataset=hf_test_dataset,
    callbacks=[MemoryManagementCallback()]  # Add the callback here
)

# Train the model
trainer.train()
print('ok')


  torch.utils._pytree._register_pytree_node(


Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "D:\anaconda\Lib\site-packages\IPython\core\interactiveshell.py", line 3526, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "C:\Users\akshg\AppData\Local\Temp\ipykernel_5264\2767822615.py", line 31, in <module>
    hf_train_dataset = convert_to_hf_dataset(train_dataset)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\akshg\AppData\Local\Temp\ipykernel_5264\2767822615.py", line 28, in convert_to_hf_dataset
    images, labels = zip(*[(image, label) for image, label in img_dataset])
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\akshg\AppData\Local\Temp\ipykernel_5264\2767822615.py", line 28, in <listcomp>
    images, labels = zip(*[(image, label) for image, label in img_dataset])
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\akshg\AppData\Roaming\Python\Python311\site-packages\torchvision\datasets\folde

In [2]:
import torch
import torchvision
print(torch.__version__)
print(torchvision.__version__)


2.4.1+cpu
0.19.1+cpu


In [2]:
import torch
torch.cuda.is_available()

True

In [None]:
D:\anaconda\Lib\site-packages\transformers\utils\generic.py:260: FutureWarning: `torch.utils._pytree._register_pytree_node` is deprecated. Please use `torch.utils._pytree.register_pytree_node` instead.
  torch.utils._pytree._register_pytree_node(