In [None]:
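# Class index assigned to each continent, and the total number of items per class:
# Africa        -> 0,  2103 total items
# Asia          -> 1,  8852 total items
# Europe        -> 2, 18117 total items
# North America -> 3, 14502 total items
# Oceania       -> 4,  2296 total items
# South America -> 5,  4125 total items
# The classes are strongly imbalanced (Europe has ~8.6x as many items as Africa),
# so raw accuracy alone can be a misleading metric.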

In [None]:
import os

import matplotlib.pyplot as plt
import sklearn
import torch
import torchvision

In [None]:
def count_leaf_files(directory):
    """Recursively count the non-directory entries at or below ``directory``.

    A path that ``os.listdir`` rejects as not being a directory counts as one
    leaf. Calling this on a file therefore returns 1, and an empty directory
    returns 0. Other OS errors (e.g. a missing path) propagate to the caller.
    """
    try:
        entries = os.listdir(directory)
    except NotADirectoryError:
        # Not a directory, so this path is itself a leaf.
        return 1
    return sum(count_leaf_files(os.path.join(directory, name)) for name in entries)
        
class GeoGuessrDataset(torch.utils.data.Dataset):
    """Flat-directory image dataset whose class label is encoded in each filename.

    Every ``.jpg`` file (case-insensitive) directly inside ``img_dir`` is one
    sample; subdirectories are not searched. The first character of the
    filename is parsed as a digit and mapped to a class index.

    Args:
        img_dir: directory containing the image files.
    """

    def __init__(self, img_dir):
        self.img_dir = img_dir
        # Sort so the index -> file mapping (and therefore any seeded split) is
        # reproducible; os.listdir order is arbitrary. Match on the extension
        # rather than any occurrence of 'jpg' anywhere in the name.
        self.all_files = sorted(
            name for name in os.listdir(self.img_dir)
            if name.lower().endswith('.jpg')
        )

    def __len__(self):
        return len(self.all_files)

    def __getitem__(self, idx):
        """Return ``(image, label)``.

        ``image`` is a float tensor of shape (3, H, W) scaled to [0, 1], and
        ``label`` is a 0-dim long tensor.
        """
        filename = self.all_files[idx]
        # Force 3-channel RGB so grayscale/CMYK JPEGs still batch with the rest
        # and match a 3-input-channel model.
        img = torchvision.io.read_image(
            os.path.join(self.img_dir, filename),
            mode=torchvision.io.ImageReadMode.RGB,
        )
        label = int(filename[0])
        # Prefix digits >= 1 are shifted down by one; a '0' prefix stays 0.
        # NOTE(review): if the files are prefixed 0-5 (as the header mapping
        # suggests), classes 0 and 1 collide and class 5 is never produced.
        # Confirm the file naming scheme.
        if label >= 1:
            label -= 1

        # Scale uint8 pixels from 0-255 to [0, 1] so the network sees normalized input.
        return img.float() / 255.0, torch.tensor(label, dtype=torch.long)
            
        

    

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
dataset = GeoGuessrDataset('continents')
# Seed the split so the same images land in train/test on every re-run of the
# notebook (given a stable file order in the dataset). Otherwise test images
# leak into training across runs.
SPLIT_SEED = 42
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [0.8, 0.2], generator=torch.Generator().manual_seed(SPLIT_SEED)
)

In [None]:
BATCH_SIZE = 8
# Shuffle only the training data. Evaluation order does not affect metrics, and
# a fixed order makes test results directly comparable between runs.
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

In [None]:
class ConvBlock(torch.nn.Module):
    """3x3 same-padded convolution, then 2x2 max-pool, then ReLU.

    Maps ``in_channels`` to ``out_channels`` and halves each spatial dimension
    (flooring odd sizes). Pooling before the ReLU gives the same result as
    ReLU-then-pool, since ReLU is monotonic non-decreasing.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute names are kept as-is because they define the state_dict keys.
        self.conv = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)
        self.maxpool = torch.nn.MaxPool2d(kernel_size=2)
        self.activ = torch.nn.ReLU()

    def forward(self, x):
        return self.activ(self.maxpool(self.conv(x)))

In [None]:
class Model(torch.nn.Module):
    """Three ConvBlocks followed by a 3-layer MLP that outputs class logits.

    Each ConvBlock halves the spatial size, so the flattened feature size
    depends on the input resolution. ``input_size`` must match the images fed
    to the model.

    Args:
        input_size: (height, width) of the input images. The default
            (1536, 662) reproduces the previously hard-coded layer size. Only
            the product of the pooled sizes matters, so (662, 1536) yields the
            same layers.
        num_classes: number of output logits.
    """

    def __init__(self, input_size=(1536, 662), num_classes=6):
        super().__init__()
        self.block1 = ConvBlock(3, 16)
        self.block2 = ConvBlock(16, 32)
        self.block3 = ConvBlock(32, 64)
        self.flatten = torch.nn.Flatten(start_dim=1)
        # NOTE: at the default size this layer alone holds ~129M weights
        # (64*192*82*128). Global pooling before flatten would shrink it drastically.
        self.dense1 = torch.nn.Linear(self._flattened_features(*input_size), out_features=128)
        self.dense2 = torch.nn.Linear(in_features=128, out_features=64)
        self.dense3 = torch.nn.Linear(in_features=64, out_features=num_classes)
        self.relu = torch.nn.ReLU()

    @staticmethod
    def _flattened_features(height, width, channels=64, num_pools=3):
        """Number of features after flattening block3's output for a (height, width) input."""
        for _ in range(num_pools):
            # The 3x3 convs use padding=1 and keep the size; MaxPool2d(2) halves it (floor).
            height, width = height // 2, width // 2
        return channels * height * width

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.flatten(x)
        x = self.dense1(x)
        x = self.relu(x)
        x = self.dense2(x)
        x = self.relu(x)
        x = self.dense3(x)
        return x

In [None]:
# Random stand-in batch of shape (3, 3, 1536, 662), used only to check that
# the model's forward pass runs and to inspect its output shape.
sample = torch.randn(3, 3, 1536, 662)
sample = sample.to(device)

In [None]:
# Instantiate the classifier and move its parameters to the selected device.
model = Model()
model = model.to(device)

In [None]:
# Smoke-test the forward pass on the dummy batch. Call the module (not
# .forward) so hooks run, and disable autograd so the large activation graph
# is not kept alive just to read the shape. Expect torch.Size([3, 6]).
with torch.no_grad():
    sample_out_shape = model(sample).shape
sample_out_shape

In [None]:
def train(model, epochs, lr=1e-3, dataloader=None, log_every=50):
    """Train ``model`` with Adam and cross-entropy, logging running metrics.

    Args:
        model: classifier returning raw logits of shape (batch, num_classes).
        epochs: number of full passes over ``dataloader``.
        lr: Adam learning rate. The previous hard-coded 1e-1 is far above the
            usual range for Adam and tends to make training diverge.
        dataloader: iterable of ``(X, y)`` batches. Defaults to the notebook's
            global ``train_dataloader``.
        log_every: print and record metrics every this many optimizer steps.

    Returns:
        dict with lists ``'step'``, ``'loss'`` and ``'acc'``. For each log
        point these hold the global step, the mean batch loss and the accuracy
        over all batches since the previous log point.
    """
    if dataloader is None:
        dataloader = train_dataloader
    # Send batches wherever the model's weights live.
    device = next(model.parameters()).device

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Default reduction='mean': loss.item() is already averaged over the batch.
    loss_fn = torch.nn.CrossEntropyLoss()

    history = {'step': [], 'loss': [], 'acc': []}
    window_loss, window_batches, window_correct, window_seen = 0.0, 0, 0, 0

    model.train()
    step = 0
    for epoch in range(epochs):
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            step += 1

            optimizer.zero_grad()
            outputs = model(X)
            loss = loss_fn(outputs, y)
            loss.backward()
            optimizer.step()

            # Accumulate over the whole logging window, not just the last batch.
            window_loss += loss.item()
            window_batches += 1
            window_correct += (outputs.argmax(dim=1) == y).sum().item()
            window_seen += y.numel()

            if step % log_every == 0:
                mean_loss = window_loss / window_batches
                acc = window_correct / window_seen
                print(f'epoch {epoch + 1}, step {step}: loss {mean_loss:.4f}, acc {acc:.3f}')
                history['step'].append(step)
                history['loss'].append(mean_loss)
                history['acc'].append(acc)
                window_loss, window_batches, window_correct, window_seen = 0.0, 0, 0, 0

    return history
            
        


        

In [None]:
# Keep whatever train() returns (e.g. logged metrics) for later inspection,
# rather than echoing it as the cell's output.
history = train(model, epochs=1)

In [None]:
# Preview one example to sanity-check image loading and its label.
preview_img, preview_label = dataset[1]
print(f'image shape: {tuple(preview_img.shape)}, dtype: {preview_img.dtype}, label: {int(preview_label)}')

# The dataset yields channels-first (C, H, W) tensors; imshow expects (H, W, C).
img_hwc = preview_img.permute(1, 2, 0).cpu()
if img_hwc.shape[-1] == 1:
    # Single-channel image: imshow needs a 2-D array for grayscale.
    img_hwc = img_hwc[..., 0]
# imshow clips float RGB data to [0, 1], so raw 0-255 floats would render
# almost white. Convert them to uint8, which imshow interprets on a 0-255 scale.
if img_hwc.is_floating_point() and img_hwc.max() > 1:
    img_hwc = img_hwc.clamp(0, 255).to(torch.uint8)

fig, ax = plt.subplots()
ax.imshow(img_hwc.numpy(), cmap='gray' if img_hwc.ndim == 2 else None)
ax.set_title(f'label {int(preview_label)}')
ax.axis('off')
plt.show()
