
I have no GPU on my laptop; I want to run a GAN model through Colab #4072

Closed

ibad321 opened this issue Aug 23, 2024 · 1 comment

Comments

ibad321 commented Aug 23, 2024

What is your question?

I have no GPU on my laptop, so I want to run a GAN model on a Colab GPU with the Flower framework, i.e. a GAN with federated learning. For example, I want to train a GAN on the CIFAR-100 dataset through the Flower framework. Can anyone guide me?
Thanks

WilliamLindskog (Contributor) commented

Hi,

Thanks for raising this. You can adapt this example for a GAN. Your client function could look something like:

import torch
import torch.nn as nn
import torch.optim as optim

from flwr.client import ClientApp, NumPyClient

# Fall back to CPU when no GPU is available (e.g. Colab without a GPU runtime)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class FlowerGANClient(NumPyClient):
    def __init__(self, generator, discriminator, dataloader, z_dim):
        self.generator = generator
        self.discriminator = discriminator
        self.dataloader = dataloader
        self.z_dim = z_dim

    def get_parameters(self, config):
        # Ship generator and discriminator weights as one flat list of NumPy
        # arrays; detach() is required because numpy() cannot be called on
        # tensors that require grad
        return [
            *[p.detach().cpu().numpy() for p in self.generator.parameters()],
            *[p.detach().cpu().numpy() for p in self.discriminator.parameters()],
        ]

    def set_parameters(self, parameters):
        # Split the flat list back into generator and discriminator weights
        num_gen = len(list(self.generator.parameters()))
        gen_params = parameters[:num_gen]
        disc_params = parameters[num_gen:]
        for param, value in zip(self.generator.parameters(), gen_params):
            param.data = torch.tensor(value, device=DEVICE)
        for param, value in zip(self.discriminator.parameters(), disc_params):
            param.data = torch.tensor(value, device=DEVICE)

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        self.train_gan()
        return self.get_parameters(config), len(self.dataloader.dataset), {}

    def evaluate(self, parameters, config):
        # GANs have no single obvious evaluation metric; return a dummy loss
        self.set_parameters(parameters)
        return 0.0, len(self.dataloader.dataset), {}

    def train_gan(self):
        optimizer_gen = optim.Adam(self.generator.parameters(), lr=0.0001)
        optimizer_disc = optim.Adam(self.discriminator.parameters(), lr=0.0001)
        criterion = nn.BCELoss()

        for epoch in range(1):  # one local epoch per federated round
            for real, _ in self.dataloader:
                # IMG_DIM is assumed defined, e.g. 3 * 32 * 32 = 3072 for CIFAR-100
                real = real.view(-1, IMG_DIM).to(DEVICE)
                batch_size = real.size(0)

                # Train Discriminator on a real batch and a detached fake batch
                noise = torch.randn(batch_size, self.z_dim, device=DEVICE)
                fake = self.generator(noise)
                disc_real = self.discriminator(real)
                disc_fake = self.discriminator(fake.detach())
                loss_disc = criterion(disc_real, torch.ones_like(disc_real)) + criterion(
                    disc_fake, torch.zeros_like(disc_fake)
                )
                optimizer_disc.zero_grad()
                loss_disc.backward()
                optimizer_disc.step()

                # Train Generator
                output = self.discriminator(fake)
                loss_gen = criterion(output, torch.ones_like(output))
                optimizer_gen.zero_grad()
                loss_gen.backward()
                optimizer_gen.step()

# Define the client function (Generator, Discriminator, dataloader, Z_DIM and
# IMG_DIM are assumed to be defined as in the linked example)
def client_fn(context):
    generator = Generator(z_dim=Z_DIM, img_dim=IMG_DIM).to(DEVICE)
    discriminator = Discriminator(img_dim=IMG_DIM).to(DEVICE)
    return FlowerGANClient(generator, discriminator, dataloader, Z_DIM).to_client()

# Create the Flower ClientApp
client_app = ClientApp(client_fn=client_fn) 
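
Since you want CIFAR-100 specifically, here is a minimal sketch of how the dataloader could be built with flwr-datasets (pip install flwr-datasets). The column names "img" and "fine_label" come from the Hugging Face cifar100 dataset, and the partition id, batch size, and transform are only illustrative:

import torch
from flwr_datasets import FederatedDataset
from torch.utils.data import DataLoader
from torchvision import transforms

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

def collate_fn(batch):
    # Each item is a dict with a PIL image ("img") and a label ("fine_label");
    # return (images, labels) tuples so the training loop above works unchanged
    imgs = torch.stack([transform(item["img"]) for item in batch])
    labels = torch.tensor([item["fine_label"] for item in batch])
    return imgs, labels

# Split CIFAR-100 into 10 IID partitions and load one of them for this client
fds = FederatedDataset(dataset="cifar100", partitioners={"train": 10})
partition = fds.load_partition(0, "train")
dataloader = DataLoader(partition, batch_size=32, shuffle=True, collate_fn=collate_fn)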

But make sure you have access to a GPU (in Colab: Runtime → Change runtime type → select a GPU) so that the cuda calls work; with the DEVICE fallback above, the code will still run on CPU, just much more slowly.
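
To run the federation end to end you also need a server side. Here is a minimal sketch, assuming a recent Flower version (the ServerAppComponents pattern) and plain FedAvg; num_rounds and num_supernodes are only illustrative:

from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg
from flwr.simulation import run_simulation

def server_fn(context):
    # FedAvg simply averages the concatenated generator + discriminator
    # weight lists that get_parameters returns
    return ServerAppComponents(strategy=FedAvg(), config=ServerConfig(num_rounds=3))

server_app = ServerApp(server_fn=server_fn)

# Run everything locally, e.g. in a single Colab notebook
run_simulation(server_app=server_app, client_app=client_app, num_supernodes=2)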

Closing this issue for now.
