<a href="https://colab.research.google.com/github/SachinSelvaraj06/Steel-Defect-Detection/blob/main/Steel_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [57]:
import os
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from tqdm import tqdm

In [58]:
# Locations of the steel dataset: the label CSV, the labelled training images,
# and the unlabelled test images.
train_csv, train_img_dir, test_img_dir = (
    '/content/steel/train.csv',
    '/content/steel/train_images',
    '/content/steel/test_images',
)

In [59]:
class DefectDataset(Dataset):
    """Image dataset backed by a flat directory of ``.jpg``/``.png`` files.

    In training mode (``train=True``) each item is ``(image, label)``, where
    the label is the CSV ``ClassId`` minus one. Only images that exist both in
    ``img_dir`` and in the CSV are kept; when an image appears on several CSV
    rows, the last row determines its label.

    In inference mode (``train=False``) every image in ``img_dir`` is kept and
    each item is ``(image, file_name)``.

    Args:
        img_dir: Directory that contains the image files.
        csv_file: Path to a CSV file, or an already loaded DataFrame, with
            ``ImageId`` and ``ClassId`` columns. Required when ``train`` is
            true; ignored otherwise.
        transform: Optional callable applied to each RGB PIL image.
        train: Whether labels are loaded and returned.

    Raises:
        ValueError: If ``train`` is true but ``csv_file`` is ``None``.
    """

    # File-name suffixes that are treated as images (case-sensitive).
    IMAGE_EXTENSIONS = ('.jpg', '.png')

    def __init__(self, img_dir, csv_file=None, transform=None, train=True):
        self.img_dir = img_dir
        self.transform = transform
        self.train = train

        # os.listdir returns entries in arbitrary order; sorting keeps sample
        # indices (and therefore prediction order) stable across runs.
        self.image_names = sorted(
            f for f in os.listdir(img_dir) if f.endswith(self.IMAGE_EXTENSIONS)
        )

        if train:
            if csv_file is None:
                raise ValueError("csv_file is required when train=True")

            if isinstance(csv_file, (str, os.PathLike)):
                self.data = pd.read_csv(csv_file)
            else:
                # Anything else is assumed to already be a DataFrame.
                self.data = csv_file

            # Set membership avoids rescanning the file list for every CSV row.
            available = set(self.image_names)
            self.labels = {
                img: cls - 1  # ClassId is shifted down by one to be zero-based
                for img, cls in zip(self.data['ImageId'], self.data['ClassId'])
                if img in available
            }

            # Drop images on disk that have no label.
            self.image_names = [img for img in self.image_names if img in self.labels]

    def __len__(self):
        """Return the number of usable images."""
        return len(self.image_names)

    def __getitem__(self, idx):
        """Load the image at ``idx`` and return it with its label or file name."""
        img_name = self.image_names[idx]
        img_path = os.path.join(self.img_dir, img_name)

        # The context manager closes the file handle as soon as the pixels
        # have been copied into the converted image.
        with Image.open(img_path) as raw:
            image = raw.convert("RGB")

        if self.transform:
            image = self.transform(image)

        if self.train:
            return image, self.labels[img_name]
        return image, img_name


In [60]:
# Preprocessing applied to every image: resize to 128x128, then convert the
# PIL image to a tensor.
transform = transforms.Compose(
    [transforms.Resize(size=(128, 128)), transforms.ToTensor()]
)

In [61]:
# Labelled training split, reshuffled each time the loader is iterated.
train_dataset = DefectDataset(
    img_dir=train_img_dir,
    csv_file=train_csv,
    transform=transform,
    train=True,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)

# Unlabelled test split, iterated in dataset order.
test_dataset = DefectDataset(
    img_dir=test_img_dir,
    transform=transform,
    train=False,
)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

In [62]:
class SimpleCNN(nn.Module):
    """Small convolutional classifier: three conv/ReLU/max-pool stages and an MLP head.

    Each stage halves the spatial resolution, so a 128x128 input reaches the
    head as a 64-channel 16x16 feature map. An adaptive average pool in front
    of the head brings any other input resolution to that same 16x16 grid; for
    128x128 inputs it is an identity operation. The pool has no parameters, so
    the ``state_dict`` keys are the same as without it.

    Args:
        num_classes: Number of output logits. Defaults to 4.
        in_channels: Number of channels in the input images. Defaults to 3.
    """

    def __init__(self, num_classes=4, in_channels=3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Kept outside ``features``/``classifier`` so the layer indices inside
        # those Sequentials (and thus checkpoint keys) are not shifted.
        self.pool = nn.AdaptiveAvgPool2d((16, 16))
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16 * 16 * 64, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for image batch ``x``."""
        x = self.features(x)
        x = self.pool(x)
        return self.classifier(x)

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = SimpleCNN()
model = model.to(device)

In [63]:
# Adam over every model parameter, paired with a cross-entropy objective.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

In [64]:
# One pass over the training data. The loss is accumulated so the cell reports
# how training went instead of discarding the value after backpropagation.
model.train()
running_loss = 0.0
samples_seen = 0
for images, labels in tqdm(train_loader):
    images = images.to(device)
    labels = labels.long().to(device)

    optimizer.zero_grad()
    outputs = model(images)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    # With CrossEntropyLoss's default reduction the loss is a batch mean;
    # weight it by batch size so a smaller final batch does not skew the
    # overall mean.
    n_in_batch = labels.size(0)
    running_loss += loss.item() * n_in_batch
    samples_seen += n_in_batch

if samples_seen:
    print(f"Mean training loss: {running_loss / samples_seen:.4f}")
else:
    print("No training samples were processed; the model was not updated.")


100%|██████████| 1/1 [00:00<00:00,  3.14it/s]


In [65]:
# Run the model over the test set without tracking gradients and collect one
# record per image. Predicted indices are shifted up by one, mirroring the
# minus-one shift applied to the training labels.
model.eval()
predictions = []
with torch.no_grad():
    for images, img_names in tqdm(test_loader):
        outputs = model(images.to(device))
        # Index of the largest logit along the class dimension.
        preds = torch.max(outputs, 1)[1].cpu().numpy()
        predictions.extend(
            {'ImageId': name, 'ClassId': class_index + 1}
            for name, class_index in zip(img_names, preds)
        )

100%|██████████| 1/1 [00:00<00:00,  2.31it/s]


In [66]:
# Explicit columns guarantee the CSV always carries the ImageId/ClassId header,
# even if `predictions` is empty (a DataFrame built from an empty list would
# otherwise have no columns).
submission_df = pd.DataFrame(predictions, columns=['ImageId', 'ClassId'])
submission_df.to_csv('/content/submission.csv', index=False)
print("Submission file saved to /content/submission.csv")

Submission file saved to /content/submission.csv
