In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler

class Net(nn.Module):
    """Two-layer fully connected classifier (784 -> 512 -> 10).

    Each input is flattened to 28*28 = 784 features (intended for MNIST
    digit images), passed through a 512-unit ReLU hidden layer, and mapped
    to 10 class scores returned as log-probabilities.
    """

    def __init__(self):
        super().__init__()
        # Network layers (created in this order so parameter init and
        # state_dict key order stay fc1, fc2).
        self.fc1 = nn.Linear(28 * 28, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return log-softmax scores of shape (N, 10) for input ``x``."""
        flat = x.view(-1, 28 * 28)
        hidden = F.relu(self.fc1(flat))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)

# Preprocessing pipeline: image -> tensor, then standardize with a single
# mean/std pair. NOTE(review): 0.1307 / 0.3081 look like the MNIST
# training-set mean / std — confirm if the dataset changes.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# Train and test splits of MNIST stored under ./data (downloaded if absent).
# The generator is consumed in order, so the train split is loaded first.
train_data, test_data = (
    datasets.MNIST(root='data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

def train_mnist(config):
    """Ray Tune trainable: fit a fresh ``Net`` on ``train_data`` with SGD.

    Args:
        config: Trial hyperparameters. Reads ``"lr"`` (SGD learning rate)
            and ``"batch_size"`` (cast to ``int``).

    Trains for 10 epochs. After each epoch it reports the mean of the
    per-batch training losses to Tune under the ``loss`` key.
    """
    model = Net()
    loss_fn = nn.NLLLoss()
    sgd = optim.SGD(model.parameters(), lr=config["lr"])
    loader = DataLoader(
        train_data, batch_size=int(config["batch_size"]), shuffle=True
    )
    n_batches = len(loader)

    for _epoch in range(10):
        epoch_loss = 0
        for batch_x, batch_y in loader:
            sgd.zero_grad()
            batch_loss = loss_fn(model(batch_x), batch_y)
            batch_loss.backward()
            sgd.step()
            epoch_loss += batch_loss.item()
        # One report per epoch; each call counts as one Tune training iteration.
        tune.report(loss=(epoch_loss / n_batches))

def _train_epochs(net, loader, criterion, optimizer, epochs):
    """Run ``epochs`` full passes of optimization over ``loader``, updating ``net`` in place."""
    net.train()
    for _ in range(epochs):
        for images, labels in loader:
            optimizer.zero_grad()
            loss = criterion(net(images), labels)
            loss.backward()
            optimizer.step()


def _evaluate(net, loader, criterion):
    """Return ``(mean per-sample loss, accuracy)`` of ``net`` over ``loader``.

    Both values are plain Python floats. Returns ``(nan, nan)`` if the
    loader yields no samples.
    """
    net.eval()
    total_loss = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for images, labels in loader:
            log_ps = net(images)
            batch = labels.size(0)
            # Assumes criterion averages over the batch (NLLLoss default);
            # re-weight by batch size so a short final batch doesn't skew
            # the overall mean.
            total_loss += criterion(log_ps, labels).item() * batch
            # argmax of log-probs == argmax of probs, so no exp() needed.
            correct += (log_ps.argmax(dim=1) == labels).sum().item()
            seen += batch
    if seen == 0:
        return float("nan"), float("nan")
    return total_loss / seen, correct / seen


def test_best_model(best_config, epochs=10):
    """Retrain a fresh ``Net`` with ``best_config`` and evaluate it on MNIST test data.

    Args:
        best_config: Hyperparameters. Reads ``"lr"`` and ``"batch_size"``
            (cast to ``int``).
        epochs: Number of training epochs (default 10).

    Returns:
        ``(test_loss, test_accuracy)`` as floats. Both are also printed.
    """
    batch_size = int(best_config["batch_size"])
    net = Net()
    criterion = nn.NLLLoss()
    optimizer = optim.SGD(net.parameters(), lr=best_config["lr"])
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    # Evaluation order does not affect the metrics, so no shuffling.
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    _train_epochs(net, train_loader, criterion, optimizer, epochs)
    test_loss, accuracy = _evaluate(net, test_loader, criterion)
    print(f"Test loss: {test_loss}.. Test accuracy: {accuracy}")
    return test_loss, accuracy

# Hyperparameter search space: log-uniform learning rate, categorical batch size.
search_space = {
    "lr": tune.loguniform(1e-4, 1e-1),
    "batch_size": tune.choice([32, 64, 128]),
}

# ASHA scheduler: stops under-performing trials early based on the reported
# "loss" (lower is better), allowing at most 10 iterations per trial.
scheduler = ASHAScheduler(
    metric="loss",
    mode="min",
    max_t=10,
    grace_period=1,
    reduction_factor=2,
)

# Console progress table showing the reported loss and iteration count.
reporter = CLIReporter(metric_columns=["loss", "training_iteration"])

# 10 sampled trials, one CPU each.
result = tune.run(
    train_mnist,
    config=search_space,
    num_samples=10,
    resources_per_trial={"cpu": 1},
    scheduler=scheduler,
    progress_reporter=reporter,
)

# Pick the trial with the lowest loss at its last reported iteration, then
# train and evaluate a fresh model with its hyperparameters.
best_trial = result.get_best_trial("loss", "min", "last")
print(f"Best trial config: {best_trial.config}")
test_best_model(best_trial.config)

2023-05-02 13:49:51,826	INFO worker.py:1625 -- Started a local Ray instance.
2023-05-02 13:49:55,134	INFO tune.py:218 -- Initializing Ray automatically. For cluster usage or custom Ray initialization, call `ray.init(...)` before `tune.run(...)`.


== Status ==
Current time: 2023-05-02 13:50:07 (running for 00:00:11.01)
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 8.000: None | Iter 4.000: None | Iter 2.000: None | Iter 1.000: None
Logical resource usage: 1.0/8 CPUs, 0/1 GPUs
Result logdir: C:\Users\Gollo\ray_results\train_mnist_2023-05-02_13-49-56
Number of trials: 10/10 (9 PENDING, 1 RUNNING)
+-------------------------+----------+-----------------+--------------+-------------+
| Trial name              | status   | loc             |   batch_size |          lr |
|-------------------------+----------+-----------------+--------------+-------------|
| train_mnist_8358b_00000 | RUNNING  | 127.0.0.1:24308 |           64 | 0.00100316  |
| train_mnist_8358b_00001 | PENDING  |                 |           32 | 0.00942916  |
| train_mnist_8358b_00002 | PENDING  |                 |          128 | 0.00960475  |
| train_mnist_8358b_00003 | PENDING  |                 |           32 | 0.0261562   |
| train_mnist_8358b_00004 | PENDING  |  



== Status ==
Current time: 2023-05-02 13:50:51 (running for 00:00:55.04)
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 8.000: None | Iter 4.000: None | Iter 2.000: None | Iter 1.000: None
Logical resource usage: 6.0/8 CPUs, 0/1 GPUs
Result logdir: C:\Users\Gollo\ray_results\train_mnist_2023-05-02_13-49-56
Number of trials: 10/10 (4 PENDING, 6 RUNNING)
+-------------------------+----------+-----------------+--------------+-------------+
| Trial name              | status   | loc             |   batch_size |          lr |
|-------------------------+----------+-----------------+--------------+-------------|
| train_mnist_8358b_00000 | RUNNING  | 127.0.0.1:24308 |           64 | 0.00100316  |
| train_mnist_8358b_00001 | RUNNING  | 127.0.0.1:4828  |           32 | 0.00942916  |
| train_mnist_8358b_00002 | RUNNING  | 127.0.0.1:6996  |          128 | 0.00960475  |
| train_mnist_8358b_00003 | RUNNING  | 127.0.0.1:13708 |           32 | 0.0261562   |
| train_mnist_8358b_00004 | RUNNING  | 1



== Status ==
Current time: 2023-05-02 13:51:04 (running for 00:01:07.84)
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 8.000: None | Iter 4.000: None | Iter 2.000: None | Iter 1.000: None
Logical resource usage: 7.0/8 CPUs, 0/1 GPUs
Result logdir: C:\Users\Gollo\ray_results\train_mnist_2023-05-02_13-49-56
Number of trials: 10/10 (3 PENDING, 7 RUNNING)
+-------------------------+----------+-----------------+--------------+-------------+
| Trial name              | status   | loc             |   batch_size |          lr |
|-------------------------+----------+-----------------+--------------+-------------|
| train_mnist_8358b_00000 | RUNNING  | 127.0.0.1:24308 |           64 | 0.00100316  |
| train_mnist_8358b_00001 | RUNNING  | 127.0.0.1:4828  |           32 | 0.00942916  |
| train_mnist_8358b_00002 | RUNNING  | 127.0.0.1:6996  |          128 | 0.00960475  |
| train_mnist_8358b_00003 | RUNNING  | 127.0.0.1:13708 |           32 | 0.0261562   |
| train_mnist_8358b_00004 | RUNNING  | 1



== Status ==
Current time: 2023-05-02 13:51:23 (running for 00:01:26.28)
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 8.000: None | Iter 4.000: None | Iter 2.000: None | Iter 1.000: None
Logical resource usage: 8.0/8 CPUs, 0/1 GPUs
Result logdir: C:\Users\Gollo\ray_results\train_mnist_2023-05-02_13-49-56
Number of trials: 10/10 (2 PENDING, 8 RUNNING)
+-------------------------+----------+-----------------+--------------+-------------+
| Trial name              | status   | loc             |   batch_size |          lr |
|-------------------------+----------+-----------------+--------------+-------------|
| train_mnist_8358b_00000 | RUNNING  | 127.0.0.1:24308 |           64 | 0.00100316  |
| train_mnist_8358b_00001 | RUNNING  | 127.0.0.1:4828  |           32 | 0.00942916  |
| train_mnist_8358b_00002 | RUNNING  | 127.0.0.1:6996  |          128 | 0.00960475  |
| train_mnist_8358b_00003 | RUNNING  | 127.0.0.1:13708 |           32 | 0.0261562   |
| train_mnist_8358b_00004 | RUNNING  | 1

Trial name,date,done,hostname,iterations_since_restore,loss,node_ip,pid,time_since_restore,time_this_iter_s,time_total_s,timestamp,training_iteration,trial_id
train_mnist_8358b_00000,2023-05-02_13-50-40,True,DESKTOP-TV0ASU4,1,1.48896,127.0.0.1,24308,26.7542,26.7542,26.7542,1683057040,1,8358b_00000
train_mnist_8358b_00001,2023-05-02_13-54-08,True,DESKTOP-TV0ASU4,4,0.147626,127.0.0.1,4828,227.093,23.0446,227.093,1683057248,4,8358b_00001
train_mnist_8358b_00002,2023-05-02_13-55-52,True,DESKTOP-TV0ASU4,10,0.177105,127.0.0.1,6996,322.415,19.1261,322.415,1683057352,10,8358b_00002
train_mnist_8358b_00003,2023-05-02_13-56-15,True,DESKTOP-TV0ASU4,10,0.0242566,127.0.0.1,13708,336.154,18.5594,336.154,1683057375,10,8358b_00003
train_mnist_8358b_00004,2023-05-02_13-56-16,True,DESKTOP-TV0ASU4,10,0.0470406,127.0.0.1,20256,325.815,18.5009,325.815,1683057376,10,8358b_00004
train_mnist_8358b_00005,2023-05-02_13-51-24,True,DESKTOP-TV0ASU4,1,1.98898,127.0.0.1,17492,21.1793,21.1793,21.1793,1683057084,1,8358b_00005
train_mnist_8358b_00006,2023-05-02_13-51-40,True,DESKTOP-TV0ASU4,1,2.25376,127.0.0.1,16348,18.3304,18.3304,18.3304,1683057100,1,8358b_00006
train_mnist_8358b_00007,2023-05-02_13-53-03,True,DESKTOP-TV0ASU4,1,2.24927,127.0.0.1,23972,37.3628,37.3628,37.3628,1683057183,1,8358b_00007
train_mnist_8358b_00008,2023-05-02_13-56-31,True,DESKTOP-TV0ASU4,10,0.0233852,127.0.0.1,17492,244.391,14.3188,244.391,1683057391,10,8358b_00008
train_mnist_8358b_00009,2023-05-02_13-53-48,True,DESKTOP-TV0ASU4,2,0.231208,127.0.0.1,16348,81.1515,38.6416,81.1515,1683057228,2,8358b_00009


== Status ==
Current time: 2023-05-02 13:52:26 (running for 00:02:29.81)
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 8.000: None | Iter 4.000: None | Iter 2.000: None | Iter 1.000: -0.7589480167767132
Logical resource usage: 8.0/8 CPUs, 0/1 GPUs
Result logdir: C:\Users\Gollo\ray_results\train_mnist_2023-05-02_13-49-56
Number of trials: 10/10 (2 PENDING, 8 RUNNING)
+-------------------------+----------+-----------------+--------------+-------------+----------+----------------------+
| Trial name              | status   | loc             |   batch_size |          lr |     loss |   training_iteration |
|-------------------------+----------+-----------------+--------------+-------------+----------+----------------------|
| train_mnist_8358b_00000 | RUNNING  | 127.0.0.1:24308 |           64 | 0.00100316  |          |                      |
| train_mnist_8358b_00001 | RUNNING  | 127.0.0.1:4828  |           32 | 0.00942916  |          |                      |
| train_mnist_8358b_00002 |

== Status ==
Current time: 2023-05-02 13:52:47 (running for 00:02:50.78)
Using AsyncHyperBand: num_stopped=3
Bracket: Iter 8.000: None | Iter 4.000: None | Iter 2.000: None | Iter 1.000: -0.7589480167767132
Logical resource usage: 7.0/8 CPUs, 0/1 GPUs
Result logdir: C:\Users\Gollo\ray_results\train_mnist_2023-05-02_13-49-56
Number of trials: 10/10 (7 RUNNING, 3 TERMINATED)
+-------------------------+------------+-----------------+--------------+-------------+----------+----------------------+
| Trial name              | status     | loc             |   batch_size |          lr |     loss |   training_iteration |
|-------------------------+------------+-----------------+--------------+-------------+----------+----------------------|
| train_mnist_8358b_00001 | RUNNING    | 127.0.0.1:4828  |           32 | 0.00942916  | 0.43131  |                    1 |
| train_mnist_8358b_00002 | RUNNING    | 127.0.0.1:6996  |          128 | 0.00960475  | 0.758948 |                    1 |
| train_mnist_

== Status ==
Current time: 2023-05-02 13:53:09 (running for 00:03:12.80)
Using AsyncHyperBand: num_stopped=4
Bracket: Iter 8.000: None | Iter 4.000: None | Iter 2.000: -0.21299497713446616 | Iter 1.000: -0.7589480167767132
Logical resource usage: 6.0/8 CPUs, 0/1 GPUs
Result logdir: C:\Users\Gollo\ray_results\train_mnist_2023-05-02_13-49-56
Number of trials: 10/10 (6 RUNNING, 4 TERMINATED)
+-------------------------+------------+-----------------+--------------+-------------+----------+----------------------+
| Trial name              | status     | loc             |   batch_size |          lr |     loss |   training_iteration |
|-------------------------+------------+-----------------+--------------+-------------+----------+----------------------|
| train_mnist_8358b_00001 | RUNNING    | 127.0.0.1:4828  |           32 | 0.00942916  | 0.232847 |                    2 |
| train_mnist_8358b_00002 | RUNNING    | 127.0.0.1:6996  |          128 | 0.00960475  | 0.358168 |                    2 

== Status ==
Current time: 2023-05-02 13:53:30 (running for 00:03:33.65)
Using AsyncHyperBand: num_stopped=4
Bracket: Iter 8.000: None | Iter 4.000: None | Iter 2.000: -0.21299497713446616 | Iter 1.000: -0.5965306136340263
Logical resource usage: 6.0/8 CPUs, 0/1 GPUs
Result logdir: C:\Users\Gollo\ray_results\train_mnist_2023-05-02_13-49-56
Number of trials: 10/10 (6 RUNNING, 4 TERMINATED)
+-------------------------+------------+-----------------+--------------+-------------+----------+----------------------+
| Trial name              | status     | loc             |   batch_size |          lr |     loss |   training_iteration |
|-------------------------+------------+-----------------+--------------+-------------+----------+----------------------|
| train_mnist_8358b_00001 | RUNNING    | 127.0.0.1:4828  |           32 | 0.00942916  | 0.232847 |                    2 |
| train_mnist_8358b_00002 | RUNNING    | 127.0.0.1:6996  |          128 | 0.00960475  | 0.358168 |                    2 

== Status ==
Current time: 2023-05-02 13:53:53 (running for 00:03:56.39)
Using AsyncHyperBand: num_stopped=5
Bracket: Iter 8.000: None | Iter 4.000: None | Iter 2.000: -0.21217550440977018 | Iter 1.000: -0.5965306136340263
Logical resource usage: 5.0/8 CPUs, 0/1 GPUs
Result logdir: C:\Users\Gollo\ray_results\train_mnist_2023-05-02_13-49-56
Number of trials: 10/10 (5 RUNNING, 5 TERMINATED)
+-------------------------+------------+-----------------+--------------+-------------+-----------+----------------------+
| Trial name              | status     | loc             |   batch_size |          lr |      loss |   training_iteration |
|-------------------------+------------+-----------------+--------------+-------------+-----------+----------------------|
| train_mnist_8358b_00001 | RUNNING    | 127.0.0.1:4828  |           32 | 0.00942916  | 0.180569  |                    3 |
| train_mnist_8358b_00002 | RUNNING    | 127.0.0.1:6996  |          128 | 0.00960475  | 0.304188  |                 

== Status ==
Current time: 2023-05-02 13:54:21 (running for 00:04:25.05)
Using AsyncHyperBand: num_stopped=6
Bracket: Iter 8.000: None | Iter 4.000: -0.13018121738409003 | Iter 2.000: -0.21217550440977018 | Iter 1.000: -0.5965306136340263
Logical resource usage: 4.0/8 CPUs, 0/1 GPUs
Result logdir: C:\Users\Gollo\ray_results\train_mnist_2023-05-02_13-49-56
Number of trials: 10/10 (4 RUNNING, 6 TERMINATED)
+-------------------------+------------+-----------------+--------------+-------------+-----------+----------------------+
| Trial name              | status     | loc             |   batch_size |          lr |      loss |   training_iteration |
|-------------------------+------------+-----------------+--------------+-------------+-----------+----------------------|
| train_mnist_8358b_00002 | RUNNING    | 127.0.0.1:6996  |          128 | 0.00960475  | 0.249702  |                    5 |
| train_mnist_8358b_00003 | RUNNING    | 127.0.0.1:13708 |           32 | 0.0261562   | 0.0748108 | 

== Status ==
Current time: 2023-05-02 13:54:44 (running for 00:04:47.50)
Using AsyncHyperBand: num_stopped=6
Bracket: Iter 8.000: None | Iter 4.000: -0.11273613734493652 | Iter 2.000: -0.21217550440977018 | Iter 1.000: -0.5965306136340263
Logical resource usage: 4.0/8 CPUs, 0/1 GPUs
Result logdir: C:\Users\Gollo\ray_results\train_mnist_2023-05-02_13-49-56
Number of trials: 10/10 (4 RUNNING, 6 TERMINATED)
+-------------------------+------------+-----------------+--------------+-------------+-----------+----------------------+
| Trial name              | status     | loc             |   batch_size |          lr |      loss |   training_iteration |
|-------------------------+------------+-----------------+--------------+-------------+-----------+----------------------|
| train_mnist_8358b_00002 | RUNNING    | 127.0.0.1:6996  |          128 | 0.00960475  | 0.230623  |                    6 |
| train_mnist_8358b_00003 | RUNNING    | 127.0.0.1:13708 |           32 | 0.0261562   | 0.0594775 | 

== Status ==
Current time: 2023-05-02 13:55:08 (running for 00:05:11.63)
Using AsyncHyperBand: num_stopped=6
Bracket: Iter 8.000: None | Iter 4.000: -0.11273613734493652 | Iter 2.000: -0.21217550440977018 | Iter 1.000: -0.5965306136340263
Logical resource usage: 4.0/8 CPUs, 0/1 GPUs
Result logdir: C:\Users\Gollo\ray_results\train_mnist_2023-05-02_13-49-56
Number of trials: 10/10 (4 RUNNING, 6 TERMINATED)
+-------------------------+------------+-----------------+--------------+-------------+-----------+----------------------+
| Trial name              | status     | loc             |   batch_size |          lr |      loss |   training_iteration |
|-------------------------+------------+-----------------+--------------+-------------+-----------+----------------------|
| train_mnist_8358b_00002 | RUNNING    | 127.0.0.1:6996  |          128 | 0.00960475  | 0.214085  |                    7 |
| train_mnist_8358b_00003 | RUNNING    | 127.0.0.1:13708 |           32 | 0.0261562   | 0.048355  | 

== Status ==
Current time: 2023-05-02 13:55:31 (running for 00:05:34.99)
Using AsyncHyperBand: num_stopped=6
Bracket: Iter 8.000: -0.2002817982358973 | Iter 4.000: -0.11273613734493652 | Iter 2.000: -0.21217550440977018 | Iter 1.000: -0.5965306136340263
Logical resource usage: 4.0/8 CPUs, 0/1 GPUs
Result logdir: C:\Users\Gollo\ray_results\train_mnist_2023-05-02_13-49-56
Number of trials: 10/10 (4 RUNNING, 6 TERMINATED)
+-------------------------+------------+-----------------+--------------+-------------+-----------+----------------------+
| Trial name              | status     | loc             |   batch_size |          lr |      loss |   training_iteration |
|-------------------------+------------+-----------------+--------------+-------------+-----------+----------------------|
| train_mnist_8358b_00002 | RUNNING    | 127.0.0.1:6996  |          128 | 0.00960475  | 0.200282  |                    8 |
| train_mnist_8358b_00003 | RUNNING    | 127.0.0.1:13708 |           32 | 0.0261562  

== Status ==
Current time: 2023-05-02 13:55:57 (running for 00:06:00.18)
Using AsyncHyperBand: num_stopped=7
Bracket: Iter 8.000: -0.059774902488663795 | Iter 4.000: -0.11273613734493652 | Iter 2.000: -0.21217550440977018 | Iter 1.000: -0.5965306136340263
Logical resource usage: 3.0/8 CPUs, 0/1 GPUs
Result logdir: C:\Users\Gollo\ray_results\train_mnist_2023-05-02_13-49-56
Number of trials: 10/10 (3 RUNNING, 7 TERMINATED)
+-------------------------+------------+-----------------+--------------+-------------+-----------+----------------------+
| Trial name              | status     | loc             |   batch_size |          lr |      loss |   training_iteration |
|-------------------------+------------+-----------------+--------------+-------------+-----------+----------------------|
| train_mnist_8358b_00003 | RUNNING    | 127.0.0.1:13708 |           32 | 0.0261562   | 0.0286485 |                    9 |
| train_mnist_8358b_00004 | RUNNING    | 127.0.0.1:20256 |           32 | 0.014379 

== Status ==
Current time: 2023-05-02 13:56:22 (running for 00:06:25.17)
Using AsyncHyperBand: num_stopped=9
Bracket: Iter 8.000: -0.04677579638940903 | Iter 4.000: -0.11273613734493652 | Iter 2.000: -0.21217550440977018 | Iter 1.000: -0.5965306136340263
Logical resource usage: 1.0/8 CPUs, 0/1 GPUs
Result logdir: C:\Users\Gollo\ray_results\train_mnist_2023-05-02_13-49-56
Number of trials: 10/10 (1 RUNNING, 9 TERMINATED)
+-------------------------+------------+-----------------+--------------+-------------+-----------+----------------------+
| Trial name              | status     | loc             |   batch_size |          lr |      loss |   training_iteration |
|-------------------------+------------+-----------------+--------------+-------------+-----------+----------------------|
| train_mnist_8358b_00008 | RUNNING    | 127.0.0.1:17492 |           32 | 0.0272307   | 0.027699  |                    9 |
| train_mnist_8358b_00000 | TERMINATED | 127.0.0.1:24308 |           64 | 0.00100316

2023-05-02 13:56:31,311	INFO tune.py:945 -- Total run time: 396.18 seconds (394.45 seconds for the tuning loop).


== Status ==
Current time: 2023-05-02 13:56:31 (running for 00:06:34.46)
Using AsyncHyperBand: num_stopped=10
Bracket: Iter 8.000: -0.04677579638940903 | Iter 4.000: -0.11273613734493652 | Iter 2.000: -0.21217550440977018 | Iter 1.000: -0.5965306136340263
Logical resource usage: 0/8 CPUs, 0/1 GPUs
Result logdir: C:\Users\Gollo\ray_results\train_mnist_2023-05-02_13-49-56
Number of trials: 10/10 (10 TERMINATED)
+-------------------------+------------+-----------------+--------------+-------------+-----------+----------------------+
| Trial name              | status     | loc             |   batch_size |          lr |      loss |   training_iteration |
|-------------------------+------------+-----------------+--------------+-------------+-----------+----------------------|
| train_mnist_8358b_00000 | TERMINATED | 127.0.0.1:24308 |           64 | 0.00100316  | 1.48896   |                    1 |
| train_mnist_8358b_00001 | TERMINATED | 127.0.0.1:4828  |           32 | 0.00942916  | 0.14762