In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import timm
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from sklearn.metrics import classification_report
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

# ======== Config ========
img_size = 224
batch_size = 32
num_classes = 4
seed = 42  # fixes the train/val partition so the reported metrics can be reproduced
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ======== Transformations ========
transform = transforms.Compose([
    # vit_base_patch16_224 expects img_size x img_size inputs. Resizing is
    # effectively a no-op for images already at that size and protects against
    # any that are not (which would otherwise fail at batching or in the model).
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 per channel: maps ToTensor's [0, 1] range to [-1, 1].
    transforms.Normalize([0.5]*3, [0.5]*3)
])

# ======== Dataset & Dataloader ========
# NOTE(review): the folder name suggests pre-augmented images. If augmented
# copies of the same source image can land on both sides of a random split,
# validation scores will be optimistic -- confirm how the data was generated.
dataset = ImageFolder("../data/data/Aug_for_train", transform=transform)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
# A dedicated, seeded generator keeps the 80/20 partition identical across
# kernel restarts without touching the global RNG state.
split_generator = torch.Generator().manual_seed(seed)
train_ds, val_ds = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=split_generator
)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=batch_size)

# ======== ViT Model (timm pretrained) ========
# Pretrained backbone with a classification head sized for num_classes outputs.
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=num_classes)
model = model.to(device)

# ======== Loss & Optimizer ========
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)

# ======== Training Loop ========
def train_model(epochs=20):
    """Fine-tune the module-level ``model`` on ``train_loader``.

    Uses the module-level ``optimizer``, ``criterion`` and ``device``. After
    every epoch the mean of the per-batch losses is printed.

    Args:
        epochs: Number of full passes over ``train_loader``.
    """
    for epoch_idx in range(epochs):
        model.train()
        epoch_loss = 0
        for batch_images, batch_labels in train_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            optimizer.zero_grad()
            batch_loss = criterion(model(batch_images), batch_labels)
            batch_loss.backward()
            optimizer.step()

            epoch_loss += batch_loss.item()

        # Average over the number of batches, not the number of samples.
        mean_loss = epoch_loss / len(train_loader)
        print(f"Epoch {epoch_idx + 1}/{epochs}, Loss: {mean_loss:.4f}")
    print("Training complete.")

# ======== Evaluation Function ========
def evaluate_model():
    """Print a per-class precision/recall/F1 report for ``model`` on ``val_loader``.

    Reads the module-level ``model``, ``val_loader``, ``device`` and ``dataset``.
    Puts ``model`` into eval mode and runs inference with gradient tracking
    disabled; the model is left in eval mode afterwards.
    """
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            outputs = model(images)
            # Highest-scoring class index per sample, moved to CPU for sklearn.
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.numpy())
    # Pin the label set to every class index ImageFolder discovered. Without
    # `labels`, classification_report infers the classes from the data and
    # raises ValueError when one is missing from both the validation labels and
    # the predictions, because the count no longer matches `target_names`.
    class_indices = list(range(len(dataset.classes)))
    print(classification_report(
        all_labels,
        all_preds,
        labels=class_indices,
        target_names=dataset.classes,
    ))

# ======== Run ========
# Fine-tune for 20 epochs, then print the classification report computed on
# the validation loader.
train_model(epochs=20)
evaluate_model()


  from .autonotebook import tqdm as notebook_tqdm
  x = F.scaled_dot_product_attention(


Epoch 1/20, Loss: 0.2870
Epoch 2/20, Loss: 0.0655
Epoch 3/20, Loss: 0.0344
Epoch 4/20, Loss: 0.0327
Epoch 5/20, Loss: 0.0832
Epoch 6/20, Loss: 0.0241
Epoch 7/20, Loss: 0.0140
Epoch 8/20, Loss: 0.0162
Epoch 9/20, Loss: 0.0077
Epoch 10/20, Loss: 0.0071
Epoch 11/20, Loss: 0.0128
Epoch 12/20, Loss: 0.0066
Epoch 13/20, Loss: 0.0028
Epoch 14/20, Loss: 0.0023
Epoch 15/20, Loss: 0.0015
Epoch 16/20, Loss: 0.0009
Epoch 17/20, Loss: 0.0002
Epoch 18/20, Loss: 0.0001
Epoch 19/20, Loss: 0.0001
Epoch 20/20, Loss: 0.0000
Training complete.
               precision    recall  f1-score   support

Alluvial soil       0.99      0.97      0.98       108
   Black Soil       0.98      0.99      0.99       114
    Clay soil       0.99      1.00      1.00       113
     Red soil       1.00      1.00      1.00       103

     accuracy                           0.99       438
    macro avg       0.99      0.99      0.99       438
 weighted avg       0.99      0.99      0.99       438



In [None]:
# Save only the model weights (recommended)
weights_path = Path("../trained_model/vit_soil_classifier_weights_after_aug_20epoch.pth")
# torch.save does not create missing directories; create the target folder
# first so a completed training run is not lost to a failed write.
weights_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), weights_path)
print("✅ Model weights saved.")


✅ Model weights saved.
